1 //===- Dialect.cpp - Dialect implementation -------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Dialect.h" 10 #include "mlir/IR/BuiltinDialect.h" 11 #include "mlir/IR/Diagnostics.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/DialectInterface.h" 14 #include "mlir/IR/DialectRegistry.h" 15 #include "mlir/IR/ExtensibleDialect.h" 16 #include "mlir/IR/MLIRContext.h" 17 #include "mlir/IR/Operation.h" 18 #include "mlir/Support/TypeID.h" 19 #include "llvm/ADT/MapVector.h" 20 #include "llvm/ADT/SetOperations.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include "llvm/ADT/SmallVectorExtras.h" 23 #include "llvm/ADT/Twine.h" 24 #include "llvm/Support/Debug.h" 25 #include "llvm/Support/ManagedStatic.h" 26 #include "llvm/Support/Regex.h" 27 #include <memory> 28 29 #define DEBUG_TYPE "dialect" 30 31 using namespace mlir; 32 using namespace detail; 33 34 //===----------------------------------------------------------------------===// 35 // Dialect 36 //===----------------------------------------------------------------------===// 37 38 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id) 39 : name(name), dialectID(id), context(context) { 40 assert(isValidNamespace(name) && "invalid dialect namespace"); 41 } 42 43 Dialect::~Dialect() = default; 44 45 /// Verify an attribute from this dialect on the argument at 'argIndex' for 46 /// the region at 'regionIndex' on the given operation. Returns failure if 47 /// the verification failed, success otherwise. This hook may optionally be 48 /// invoked from any operation containing a region. 49 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned, 50 NamedAttribute) { 51 return success(); 52 } 53 54 /// Verify an attribute from this dialect on the result at 'resultIndex' for 55 /// the region at 'regionIndex' on the given operation. Returns failure if 56 /// the verification failed, success otherwise. This hook may optionally be 57 /// invoked from any operation containing a region. 58 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned, 59 unsigned, NamedAttribute) { 60 return success(); 61 } 62 63 /// Parse an attribute registered to this dialect. 64 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const { 65 parser.emitError(parser.getNameLoc()) 66 << "dialect '" << getNamespace() 67 << "' provides no attribute parsing hook"; 68 return Attribute(); 69 } 70 71 /// Parse a type registered to this dialect. 72 Type Dialect::parseType(DialectAsmParser &parser) const { 73 // If this dialect allows unknown types, then represent this with OpaqueType. 74 if (allowsUnknownTypes()) { 75 StringAttr ns = StringAttr::get(getContext(), getNamespace()); 76 return OpaqueType::get(ns, parser.getFullSymbolSpec()); 77 } 78 79 parser.emitError(parser.getNameLoc()) 80 << "dialect '" << getNamespace() << "' provides no type parsing hook"; 81 return Type(); 82 } 83 84 std::optional<Dialect::ParseOpHook> 85 Dialect::getParseOperationHook(StringRef opName) const { 86 return std::nullopt; 87 } 88 89 llvm::unique_function<void(Operation *, OpAsmPrinter &printer)> 90 Dialect::getOperationPrinter(Operation *op) const { 91 assert(op->getDialect() == this && 92 "Dialect hook invoked on non-dialect owned operation"); 93 return nullptr; 94 } 95 96 /// Utility function that returns if the given string is a valid dialect 97 /// namespace 98 bool Dialect::isValidNamespace(StringRef str) { 99 llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$"); 100 return dialectNameRegex.match(str); 101 } 102 103 /// Register a set of dialect interfaces with this dialect instance. 104 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) { 105 // Handle the case where the models resolve a promised interface. 106 handleAdditionOfUndefinedPromisedInterface(getTypeID(), interface->getID()); 107 108 auto it = registeredInterfaces.try_emplace(interface->getID(), 109 std::move(interface)); 110 (void)it; 111 LLVM_DEBUG({ 112 if (!it.second) { 113 llvm::dbgs() << "[" DEBUG_TYPE 114 "] repeated interface registration for dialect " 115 << getNamespace(); 116 } 117 }); 118 } 119 120 //===----------------------------------------------------------------------===// 121 // Dialect Interface 122 //===----------------------------------------------------------------------===// 123 124 DialectInterface::~DialectInterface() = default; 125 126 MLIRContext *DialectInterface::getContext() const { 127 return dialect->getContext(); 128 } 129 130 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase( 131 MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) { 132 for (auto *dialect : ctx->getLoadedDialects()) { 133 #ifndef NDEBUG 134 dialect->handleUseOfUndefinedPromisedInterface( 135 dialect->getTypeID(), interfaceKind, interfaceName); 136 #endif 137 if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) { 138 interfaces.insert(interface); 139 orderedInterfaces.push_back(interface); 140 } 141 } 142 } 143 144 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() = default; 145 146 /// Get the interface for the dialect of given operation, or null if one 147 /// is not registered. 148 const DialectInterface * 149 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const { 150 return getInterfaceFor(op->getDialect()); 151 } 152 153 //===----------------------------------------------------------------------===// 154 // DialectExtension 155 //===----------------------------------------------------------------------===// 156 157 DialectExtensionBase::~DialectExtensionBase() = default; 158 159 void dialect_extension_detail::handleUseOfUndefinedPromisedInterface( 160 Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID, 161 StringRef interfaceName) { 162 dialect.handleUseOfUndefinedPromisedInterface(interfaceRequestorID, 163 interfaceID, interfaceName); 164 } 165 166 void dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface( 167 Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID) { 168 dialect.handleAdditionOfUndefinedPromisedInterface(interfaceRequestorID, 169 interfaceID); 170 } 171 172 bool dialect_extension_detail::hasPromisedInterface(Dialect &dialect, 173 TypeID interfaceRequestorID, 174 TypeID interfaceID) { 175 return dialect.hasPromisedInterface(interfaceRequestorID, interfaceID); 176 } 177 178 //===----------------------------------------------------------------------===// 179 // DialectRegistry 180 //===----------------------------------------------------------------------===// 181 182 namespace { 183 template <typename Fn> 184 void applyExtensionsFn( 185 Fn &&applyExtension, 186 const llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>> 187 &extensions) { 188 // Note: Additional extensions may be added while applying an extension. 189 // The iterators will be invalidated if extensions are added so we'll keep 190 // a copy of the extensions for ourselves. 191 192 const auto extractExtension = 193 [](const auto &entry) -> DialectExtensionBase * { 194 return entry.second.get(); 195 }; 196 197 auto startIt = extensions.begin(), endIt = extensions.end(); 198 size_t count = 0; 199 while (startIt != endIt) { 200 count += endIt - startIt; 201 202 // Grab the subset of extensions we'll apply in this iteration. 203 const auto subset = 204 llvm::map_to_vector(llvm::make_range(startIt, endIt), extractExtension); 205 206 for (const auto *ext : subset) 207 applyExtension(*ext); 208 209 // Book-keep for the next iteration. 210 startIt = extensions.begin() + count; 211 endIt = extensions.end(); 212 } 213 } 214 } // namespace 215 216 DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); } 217 218 DialectAllocatorFunctionRef 219 DialectRegistry::getDialectAllocator(StringRef name) const { 220 auto it = registry.find(name); 221 if (it == registry.end()) 222 return nullptr; 223 return it->second.second; 224 } 225 226 void DialectRegistry::insert(TypeID typeID, StringRef name, 227 const DialectAllocatorFunction &ctor) { 228 auto inserted = registry.insert( 229 std::make_pair(std::string(name), std::make_pair(typeID, ctor))); 230 if (!inserted.second && inserted.first->second.first != typeID) { 231 llvm::report_fatal_error( 232 "Trying to register different dialects for the same namespace: " + 233 name); 234 } 235 } 236 237 void DialectRegistry::insertDynamic( 238 StringRef name, const DynamicDialectPopulationFunction &ctor) { 239 // This TypeID marks dynamic dialects. We cannot give a TypeID for the 240 // dialect yet, since the TypeID of a dynamic dialect is defined at its 241 // construction. 242 TypeID typeID = TypeID::get<void>(); 243 244 // Create the dialect, and then call ctor, which allocates its components. 245 auto constructor = [nameStr = name.str(), ctor](MLIRContext *ctx) { 246 auto *dynDialect = ctx->getOrLoadDynamicDialect( 247 nameStr, [ctx, ctor](DynamicDialect *dialect) { ctor(ctx, dialect); }); 248 assert(dynDialect && "Dynamic dialect creation unexpectedly failed"); 249 return dynDialect; 250 }; 251 252 insert(typeID, name, constructor); 253 } 254 255 void DialectRegistry::applyExtensions(Dialect *dialect) const { 256 MLIRContext *ctx = dialect->getContext(); 257 StringRef dialectName = dialect->getNamespace(); 258 259 // Functor used to try to apply the given extension. 260 auto applyExtension = [&](const DialectExtensionBase &extension) { 261 ArrayRef<StringRef> dialectNames = extension.getRequiredDialects(); 262 // An empty set is equivalent to always invoke. 263 if (dialectNames.empty()) { 264 extension.apply(ctx, dialect); 265 return; 266 } 267 268 // Handle the simple case of a single dialect name. In this case, the 269 // required dialect should be the current dialect. 270 if (dialectNames.size() == 1) { 271 if (dialectNames.front() == dialectName) 272 extension.apply(ctx, dialect); 273 return; 274 } 275 276 // Otherwise, check to see if this extension requires this dialect. 277 const StringRef *nameIt = llvm::find(dialectNames, dialectName); 278 if (nameIt == dialectNames.end()) 279 return; 280 281 // If it does, ensure that all of the other required dialects have been 282 // loaded. 283 SmallVector<Dialect *> requiredDialects; 284 requiredDialects.reserve(dialectNames.size()); 285 for (auto it = dialectNames.begin(), e = dialectNames.end(); it != e; 286 ++it) { 287 // The current dialect is known to be loaded. 288 if (it == nameIt) { 289 requiredDialects.push_back(dialect); 290 continue; 291 } 292 // Otherwise, check if it is loaded. 293 Dialect *loadedDialect = ctx->getLoadedDialect(*it); 294 if (!loadedDialect) 295 return; 296 requiredDialects.push_back(loadedDialect); 297 } 298 extension.apply(ctx, requiredDialects); 299 }; 300 301 applyExtensionsFn(applyExtension, extensions); 302 } 303 304 void DialectRegistry::applyExtensions(MLIRContext *ctx) const { 305 // Functor used to try to apply the given extension. 306 auto applyExtension = [&](const DialectExtensionBase &extension) { 307 ArrayRef<StringRef> dialectNames = extension.getRequiredDialects(); 308 if (dialectNames.empty()) { 309 auto loadedDialects = ctx->getLoadedDialects(); 310 extension.apply(ctx, loadedDialects); 311 return; 312 } 313 314 // Check to see if all of the dialects for this extension are loaded. 315 SmallVector<Dialect *> requiredDialects; 316 requiredDialects.reserve(dialectNames.size()); 317 for (StringRef dialectName : dialectNames) { 318 Dialect *loadedDialect = ctx->getLoadedDialect(dialectName); 319 if (!loadedDialect) 320 return; 321 requiredDialects.push_back(loadedDialect); 322 } 323 extension.apply(ctx, requiredDialects); 324 }; 325 326 applyExtensionsFn(applyExtension, extensions); 327 } 328 329 bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const { 330 // Check that all extension keys are present in 'rhs'. 331 const auto hasExtension = [&](const auto &key) { 332 return rhs.extensions.contains(key); 333 }; 334 if (!llvm::all_of(make_first_range(extensions), hasExtension)) 335 return false; 336 337 // Check that the current dialects fully overlap with the dialects in 'rhs'. 338 return llvm::all_of( 339 registry, [&](const auto &it) { return rhs.registry.count(it.first); }); 340 } 341