1 //===- Dialect.cpp - Dialect implementation -------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Dialect.h" 10 #include "mlir/IR/BuiltinDialect.h" 11 #include "mlir/IR/Diagnostics.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/DialectInterface.h" 14 #include "mlir/IR/MLIRContext.h" 15 #include "mlir/IR/Operation.h" 16 #include "llvm/ADT/MapVector.h" 17 #include "llvm/ADT/Twine.h" 18 #include "llvm/Support/Debug.h" 19 #include "llvm/Support/ManagedStatic.h" 20 #include "llvm/Support/Regex.h" 21 22 #define DEBUG_TYPE "dialect" 23 24 using namespace mlir; 25 using namespace detail; 26 27 //===----------------------------------------------------------------------===// 28 // Dialect 29 //===----------------------------------------------------------------------===// 30 31 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id) 32 : name(name), dialectID(id), context(context) { 33 assert(isValidNamespace(name) && "invalid dialect namespace"); 34 } 35 36 Dialect::~Dialect() = default; 37 38 /// Verify an attribute from this dialect on the argument at 'argIndex' for 39 /// the region at 'regionIndex' on the given operation. Returns failure if 40 /// the verification failed, success otherwise. This hook may optionally be 41 /// invoked from any operation containing a region. 42 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned, 43 NamedAttribute) { 44 return success(); 45 } 46 47 /// Verify an attribute from this dialect on the result at 'resultIndex' for 48 /// the region at 'regionIndex' on the given operation. Returns failure if 49 /// the verification failed, success otherwise. This hook may optionally be 50 /// invoked from any operation containing a region. 51 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned, 52 unsigned, NamedAttribute) { 53 return success(); 54 } 55 56 /// Parse an attribute registered to this dialect. 57 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const { 58 parser.emitError(parser.getNameLoc()) 59 << "dialect '" << getNamespace() 60 << "' provides no attribute parsing hook"; 61 return Attribute(); 62 } 63 64 /// Parse a type registered to this dialect. 65 Type Dialect::parseType(DialectAsmParser &parser) const { 66 // If this dialect allows unknown types, then represent this with OpaqueType. 67 if (allowsUnknownTypes()) { 68 StringAttr ns = StringAttr::get(getContext(), getNamespace()); 69 return OpaqueType::get(ns, parser.getFullSymbolSpec()); 70 } 71 72 parser.emitError(parser.getNameLoc()) 73 << "dialect '" << getNamespace() << "' provides no type parsing hook"; 74 return Type(); 75 } 76 77 Optional<Dialect::ParseOpHook> 78 Dialect::getParseOperationHook(StringRef opName) const { 79 return None; 80 } 81 82 llvm::unique_function<void(Operation *, OpAsmPrinter &printer)> 83 Dialect::getOperationPrinter(Operation *op) const { 84 assert(op->getDialect() == this && 85 "Dialect hook invoked on non-dialect owned operation"); 86 return nullptr; 87 } 88 89 /// Utility function that returns if the given string is a valid dialect 90 /// namespace 91 bool Dialect::isValidNamespace(StringRef str) { 92 llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$"); 93 return dialectNameRegex.match(str); 94 } 95 96 /// Register a set of dialect interfaces with this dialect instance. 97 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) { 98 auto it = registeredInterfaces.try_emplace(interface->getID(), 99 std::move(interface)); 100 (void)it; 101 LLVM_DEBUG({ 102 if (!it.second) { 103 llvm::dbgs() << "[" DEBUG_TYPE 104 "] repeated interface registration for dialect " 105 << getNamespace(); 106 } 107 }); 108 } 109 110 //===----------------------------------------------------------------------===// 111 // Dialect Interface 112 //===----------------------------------------------------------------------===// 113 114 DialectInterface::~DialectInterface() = default; 115 116 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase( 117 MLIRContext *ctx, TypeID interfaceKind) { 118 for (auto *dialect : ctx->getLoadedDialects()) { 119 if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) { 120 interfaces.insert(interface); 121 orderedInterfaces.push_back(interface); 122 } 123 } 124 } 125 126 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() = default; 127 128 /// Get the interface for the dialect of given operation, or null if one 129 /// is not registered. 130 const DialectInterface * 131 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const { 132 return getInterfaceFor(op->getDialect()); 133 } 134 135 //===----------------------------------------------------------------------===// 136 // DialectExtension 137 //===----------------------------------------------------------------------===// 138 139 DialectExtensionBase::~DialectExtensionBase() = default; 140 141 //===----------------------------------------------------------------------===// 142 // DialectRegistry 143 //===----------------------------------------------------------------------===// 144 145 DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); } 146 147 DialectAllocatorFunctionRef 148 DialectRegistry::getDialectAllocator(StringRef name) const { 149 auto it = registry.find(name.str()); 150 if (it == registry.end()) 151 return nullptr; 152 return it->second.second; 153 } 154 155 void DialectRegistry::insert(TypeID typeID, StringRef name, 156 const DialectAllocatorFunction &ctor) { 157 auto inserted = registry.insert( 158 std::make_pair(std::string(name), std::make_pair(typeID, ctor))); 159 if (!inserted.second && inserted.first->second.first != typeID) { 160 llvm::report_fatal_error( 161 "Trying to register different dialects for the same namespace: " + 162 name); 163 } 164 } 165 166 void DialectRegistry::applyExtensions(Dialect *dialect) const { 167 MLIRContext *ctx = dialect->getContext(); 168 StringRef dialectName = dialect->getNamespace(); 169 170 // Functor used to try to apply the given extension. 171 auto applyExtension = [&](const DialectExtensionBase &extension) { 172 ArrayRef<StringRef> dialectNames = extension.getRequiredDialects(); 173 174 // Handle the simple case of a single dialect name. In this case, the 175 // required dialect should be the current dialect. 176 if (dialectNames.size() == 1) { 177 if (dialectNames.front() == dialectName) 178 extension.apply(ctx, dialect); 179 return; 180 } 181 182 // Otherwise, check to see if this extension requires this dialect. 183 const StringRef *nameIt = llvm::find(dialectNames, dialectName); 184 if (nameIt == dialectNames.end()) 185 return; 186 187 // If it does, ensure that all of the other required dialects have been 188 // loaded. 189 SmallVector<Dialect *> requiredDialects; 190 requiredDialects.reserve(dialectNames.size()); 191 for (auto it = dialectNames.begin(), e = dialectNames.end(); it != e; 192 ++it) { 193 // The current dialect is known to be loaded. 194 if (it == nameIt) { 195 requiredDialects.push_back(dialect); 196 continue; 197 } 198 // Otherwise, check if it is loaded. 199 Dialect *loadedDialect = ctx->getLoadedDialect(*it); 200 if (!loadedDialect) 201 return; 202 requiredDialects.push_back(loadedDialect); 203 } 204 extension.apply(ctx, requiredDialects); 205 }; 206 207 for (const auto &extension : extensions) 208 applyExtension(*extension); 209 } 210 211 void DialectRegistry::applyExtensions(MLIRContext *ctx) const { 212 // Functor used to try to apply the given extension. 213 auto applyExtension = [&](const DialectExtensionBase &extension) { 214 ArrayRef<StringRef> dialectNames = extension.getRequiredDialects(); 215 216 // Check to see if all of the dialects for this extension are loaded. 217 SmallVector<Dialect *> requiredDialects; 218 requiredDialects.reserve(dialectNames.size()); 219 for (StringRef dialectName : dialectNames) { 220 Dialect *loadedDialect = ctx->getLoadedDialect(dialectName); 221 if (!loadedDialect) 222 return; 223 requiredDialects.push_back(loadedDialect); 224 } 225 extension.apply(ctx, requiredDialects); 226 }; 227 228 for (const auto &extension : extensions) 229 applyExtension(*extension); 230 } 231