1 //===- Dialect.cpp - Dialect implementation -------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Dialect.h" 10 #include "mlir/IR/BuiltinDialect.h" 11 #include "mlir/IR/Diagnostics.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/DialectInterface.h" 14 #include "mlir/IR/ExtensibleDialect.h" 15 #include "mlir/IR/MLIRContext.h" 16 #include "mlir/IR/Operation.h" 17 #include "llvm/ADT/MapVector.h" 18 #include "llvm/ADT/Twine.h" 19 #include "llvm/Support/Debug.h" 20 #include "llvm/Support/ManagedStatic.h" 21 #include "llvm/Support/Regex.h" 22 23 #define DEBUG_TYPE "dialect" 24 25 using namespace mlir; 26 using namespace detail; 27 28 //===----------------------------------------------------------------------===// 29 // Dialect 30 //===----------------------------------------------------------------------===// 31 32 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id) 33 : name(name), dialectID(id), context(context) { 34 assert(isValidNamespace(name) && "invalid dialect namespace"); 35 } 36 37 Dialect::~Dialect() = default; 38 39 /// Verify an attribute from this dialect on the argument at 'argIndex' for 40 /// the region at 'regionIndex' on the given operation. Returns failure if 41 /// the verification failed, success otherwise. This hook may optionally be 42 /// invoked from any operation containing a region. 43 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned, 44 NamedAttribute) { 45 return success(); 46 } 47 48 /// Verify an attribute from this dialect on the result at 'resultIndex' for 49 /// the region at 'regionIndex' on the given operation. Returns failure if 50 /// the verification failed, success otherwise. This hook may optionally be 51 /// invoked from any operation containing a region. 52 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned, 53 unsigned, NamedAttribute) { 54 return success(); 55 } 56 57 /// Parse an attribute registered to this dialect. 58 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const { 59 parser.emitError(parser.getNameLoc()) 60 << "dialect '" << getNamespace() 61 << "' provides no attribute parsing hook"; 62 return Attribute(); 63 } 64 65 /// Parse a type registered to this dialect. 66 Type Dialect::parseType(DialectAsmParser &parser) const { 67 // If this dialect allows unknown types, then represent this with OpaqueType. 68 if (allowsUnknownTypes()) { 69 StringAttr ns = StringAttr::get(getContext(), getNamespace()); 70 return OpaqueType::get(ns, parser.getFullSymbolSpec()); 71 } 72 73 parser.emitError(parser.getNameLoc()) 74 << "dialect '" << getNamespace() << "' provides no type parsing hook"; 75 return Type(); 76 } 77 78 std::optional<Dialect::ParseOpHook> 79 Dialect::getParseOperationHook(StringRef opName) const { 80 return std::nullopt; 81 } 82 83 llvm::unique_function<void(Operation *, OpAsmPrinter &printer)> 84 Dialect::getOperationPrinter(Operation *op) const { 85 assert(op->getDialect() == this && 86 "Dialect hook invoked on non-dialect owned operation"); 87 return nullptr; 88 } 89 90 /// Utility function that returns if the given string is a valid dialect 91 /// namespace 92 bool Dialect::isValidNamespace(StringRef str) { 93 llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$"); 94 return dialectNameRegex.match(str); 95 } 96 97 /// Register a set of dialect interfaces with this dialect instance. 98 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) { 99 // Handle the case where the models resolve a promised interface. 100 handleAdditionOfUndefinedPromisedInterface(getTypeID(), interface->getID()); 101 102 auto it = registeredInterfaces.try_emplace(interface->getID(), 103 std::move(interface)); 104 (void)it; 105 LLVM_DEBUG({ 106 if (!it.second) { 107 llvm::dbgs() << "[" DEBUG_TYPE 108 "] repeated interface registration for dialect " 109 << getNamespace(); 110 } 111 }); 112 } 113 114 //===----------------------------------------------------------------------===// 115 // Dialect Interface 116 //===----------------------------------------------------------------------===// 117 118 DialectInterface::~DialectInterface() = default; 119 120 MLIRContext *DialectInterface::getContext() const { 121 return dialect->getContext(); 122 } 123 124 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase( 125 MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) { 126 for (auto *dialect : ctx->getLoadedDialects()) { 127 #ifndef NDEBUG 128 dialect->handleUseOfUndefinedPromisedInterface( 129 dialect->getTypeID(), interfaceKind, interfaceName); 130 #endif 131 if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) { 132 interfaces.insert(interface); 133 orderedInterfaces.push_back(interface); 134 } 135 } 136 } 137 138 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() = default; 139 140 /// Get the interface for the dialect of given operation, or null if one 141 /// is not registered. 142 const DialectInterface * 143 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const { 144 return getInterfaceFor(op->getDialect()); 145 } 146 147 //===----------------------------------------------------------------------===// 148 // DialectExtension 149 //===----------------------------------------------------------------------===// 150 151 DialectExtensionBase::~DialectExtensionBase() = default; 152 153 void dialect_extension_detail::handleUseOfUndefinedPromisedInterface( 154 Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID, 155 StringRef interfaceName) { 156 dialect.handleUseOfUndefinedPromisedInterface(interfaceRequestorID, 157 interfaceID, interfaceName); 158 } 159 160 void dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface( 161 Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID) { 162 dialect.handleAdditionOfUndefinedPromisedInterface(interfaceRequestorID, 163 interfaceID); 164 } 165 166 bool dialect_extension_detail::hasPromisedInterface(Dialect &dialect, 167 TypeID interfaceRequestorID, 168 TypeID interfaceID) { 169 return dialect.hasPromisedInterface(interfaceRequestorID, interfaceID); 170 } 171 172 //===----------------------------------------------------------------------===// 173 // DialectRegistry 174 //===----------------------------------------------------------------------===// 175 176 DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); } 177 178 DialectAllocatorFunctionRef 179 DialectRegistry::getDialectAllocator(StringRef name) const { 180 auto it = registry.find(name.str()); 181 if (it == registry.end()) 182 return nullptr; 183 return it->second.second; 184 } 185 186 void DialectRegistry::insert(TypeID typeID, StringRef name, 187 const DialectAllocatorFunction &ctor) { 188 auto inserted = registry.insert( 189 std::make_pair(std::string(name), std::make_pair(typeID, ctor))); 190 if (!inserted.second && inserted.first->second.first != typeID) { 191 llvm::report_fatal_error( 192 "Trying to register different dialects for the same namespace: " + 193 name); 194 } 195 } 196 197 void DialectRegistry::insertDynamic( 198 StringRef name, const DynamicDialectPopulationFunction &ctor) { 199 // This TypeID marks dynamic dialects. We cannot give a TypeID for the 200 // dialect yet, since the TypeID of a dynamic dialect is defined at its 201 // construction. 202 TypeID typeID = TypeID::get<void>(); 203 204 // Create the dialect, and then call ctor, which allocates its components. 205 auto constructor = [nameStr = name.str(), ctor](MLIRContext *ctx) { 206 auto *dynDialect = ctx->getOrLoadDynamicDialect( 207 nameStr, [ctx, ctor](DynamicDialect *dialect) { ctor(ctx, dialect); }); 208 assert(dynDialect && "Dynamic dialect creation unexpectedly failed"); 209 return dynDialect; 210 }; 211 212 insert(typeID, name, constructor); 213 } 214 215 void DialectRegistry::applyExtensions(Dialect *dialect) const { 216 MLIRContext *ctx = dialect->getContext(); 217 StringRef dialectName = dialect->getNamespace(); 218 219 // Functor used to try to apply the given extension. 220 auto applyExtension = [&](const DialectExtensionBase &extension) { 221 ArrayRef<StringRef> dialectNames = extension.getRequiredDialects(); 222 // An empty set is equivalent to always invoke. 223 if (dialectNames.empty()) { 224 extension.apply(ctx, dialect); 225 return; 226 } 227 228 // Handle the simple case of a single dialect name. In this case, the 229 // required dialect should be the current dialect. 230 if (dialectNames.size() == 1) { 231 if (dialectNames.front() == dialectName) 232 extension.apply(ctx, dialect); 233 return; 234 } 235 236 // Otherwise, check to see if this extension requires this dialect. 237 const StringRef *nameIt = llvm::find(dialectNames, dialectName); 238 if (nameIt == dialectNames.end()) 239 return; 240 241 // If it does, ensure that all of the other required dialects have been 242 // loaded. 243 SmallVector<Dialect *> requiredDialects; 244 requiredDialects.reserve(dialectNames.size()); 245 for (auto it = dialectNames.begin(), e = dialectNames.end(); it != e; 246 ++it) { 247 // The current dialect is known to be loaded. 248 if (it == nameIt) { 249 requiredDialects.push_back(dialect); 250 continue; 251 } 252 // Otherwise, check if it is loaded. 253 Dialect *loadedDialect = ctx->getLoadedDialect(*it); 254 if (!loadedDialect) 255 return; 256 requiredDialects.push_back(loadedDialect); 257 } 258 extension.apply(ctx, requiredDialects); 259 }; 260 261 // Note: Additional extensions may be added while applying an extension. 262 for (int i = 0; i < static_cast<int>(extensions.size()); ++i) 263 applyExtension(*extensions[i]); 264 } 265 266 void DialectRegistry::applyExtensions(MLIRContext *ctx) const { 267 // Functor used to try to apply the given extension. 268 auto applyExtension = [&](const DialectExtensionBase &extension) { 269 ArrayRef<StringRef> dialectNames = extension.getRequiredDialects(); 270 if (dialectNames.empty()) { 271 auto loadedDialects = ctx->getLoadedDialects(); 272 extension.apply(ctx, loadedDialects); 273 return; 274 } 275 276 // Check to see if all of the dialects for this extension are loaded. 277 SmallVector<Dialect *> requiredDialects; 278 requiredDialects.reserve(dialectNames.size()); 279 for (StringRef dialectName : dialectNames) { 280 Dialect *loadedDialect = ctx->getLoadedDialect(dialectName); 281 if (!loadedDialect) 282 return; 283 requiredDialects.push_back(loadedDialect); 284 } 285 extension.apply(ctx, requiredDialects); 286 }; 287 288 // Note: Additional extensions may be added while applying an extension. 289 for (int i = 0; i < static_cast<int>(extensions.size()); ++i) 290 applyExtension(*extensions[i]); 291 } 292 293 bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const { 294 // Treat any extensions conservatively. 295 if (!extensions.empty()) 296 return false; 297 // Check that the current dialects fully overlap with the dialects in 'rhs'. 298 return llvm::all_of( 299 registry, [&](const auto &it) { return rhs.registry.count(it.first); }); 300 } 301