1 //===- Dialect.cpp - Dialect implementation -------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Dialect.h" 10 #include "mlir/IR/BuiltinDialect.h" 11 #include "mlir/IR/Diagnostics.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/DialectInterface.h" 14 #include "mlir/IR/MLIRContext.h" 15 #include "mlir/IR/Operation.h" 16 #include "llvm/ADT/MapVector.h" 17 #include "llvm/ADT/Twine.h" 18 #include "llvm/Support/Debug.h" 19 #include "llvm/Support/ManagedStatic.h" 20 #include "llvm/Support/Regex.h" 21 22 #define DEBUG_TYPE "dialect" 23 24 using namespace mlir; 25 using namespace detail; 26 27 //===----------------------------------------------------------------------===// 28 // DialectRegistry 29 //===----------------------------------------------------------------------===// 30 31 DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); } 32 33 void DialectRegistry::addDialectInterface( 34 StringRef dialectName, TypeID interfaceTypeID, 35 DialectInterfaceAllocatorFunction allocator) { 36 assert(allocator && "unexpected null interface allocation function"); 37 auto it = registry.find(dialectName.str()); 38 assert(it != registry.end() && 39 "adding an interface for an unregistered dialect"); 40 41 // Bail out if the interface with the given ID is already in the registry for 42 // the given dialect. We expect a small number (dozens) of interfaces so a 43 // linear search is fine here. 44 auto &ifaces = interfaces[it->second.first]; 45 for (const auto &kvp : ifaces.dialectInterfaces) { 46 if (kvp.first == interfaceTypeID) { 47 LLVM_DEBUG(llvm::dbgs() 48 << "[" DEBUG_TYPE 49 "] repeated interface registration for dialect " 50 << dialectName); 51 return; 52 } 53 } 54 55 ifaces.dialectInterfaces.emplace_back(interfaceTypeID, allocator); 56 } 57 58 void DialectRegistry::addObjectInterface( 59 StringRef dialectName, TypeID objectID, TypeID interfaceTypeID, 60 ObjectInterfaceAllocatorFunction allocator) { 61 assert(allocator && "unexpected null interface allocation function"); 62 63 auto it = registry.find(dialectName.str()); 64 assert(it != registry.end() && 65 "adding an interface for an op from an unregistered dialect"); 66 67 auto dialectID = it->second.first; 68 auto &ifaces = interfaces[dialectID]; 69 70 for (const auto &info : ifaces.objectInterfaces) { 71 if (std::get<0>(info) == objectID && std::get<1>(info) == interfaceTypeID) { 72 LLVM_DEBUG(llvm::dbgs() 73 << "[" DEBUG_TYPE 74 "] repeated interface object interface registration"); 75 return; 76 } 77 } 78 79 ifaces.objectInterfaces.emplace_back(objectID, interfaceTypeID, allocator); 80 } 81 82 DialectAllocatorFunctionRef 83 DialectRegistry::getDialectAllocator(StringRef name) const { 84 auto it = registry.find(name.str()); 85 if (it == registry.end()) 86 return nullptr; 87 return it->second.second; 88 } 89 90 void DialectRegistry::insert(TypeID typeID, StringRef name, 91 DialectAllocatorFunction ctor) { 92 auto inserted = registry.insert( 93 std::make_pair(std::string(name), std::make_pair(typeID, ctor))); 94 if (!inserted.second && inserted.first->second.first != typeID) { 95 llvm::report_fatal_error( 96 "Trying to register different dialects for the same namespace: " + 97 name); 98 } 99 } 100 101 void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const { 102 auto it = interfaces.find(dialect->getTypeID()); 103 if (it == interfaces.end()) 104 return; 105 106 // Add an interface if it is not already present. 107 for (const auto &kvp : it->getSecond().dialectInterfaces) { 108 if (dialect->getRegisteredInterface(kvp.first)) 109 continue; 110 dialect->addInterface(kvp.second(dialect)); 111 } 112 113 // Add attribute, operation and type interfaces. 114 for (const auto &info : it->getSecond().objectInterfaces) 115 std::get<2>(info)(dialect->getContext()); 116 } 117 118 //===----------------------------------------------------------------------===// 119 // Dialect 120 //===----------------------------------------------------------------------===// 121 122 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id) 123 : name(name), dialectID(id), context(context) { 124 assert(isValidNamespace(name) && "invalid dialect namespace"); 125 } 126 127 Dialect::~Dialect() {} 128 129 /// Verify an attribute from this dialect on the argument at 'argIndex' for 130 /// the region at 'regionIndex' on the given operation. Returns failure if 131 /// the verification failed, success otherwise. This hook may optionally be 132 /// invoked from any operation containing a region. 133 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned, 134 NamedAttribute) { 135 return success(); 136 } 137 138 /// Verify an attribute from this dialect on the result at 'resultIndex' for 139 /// the region at 'regionIndex' on the given operation. Returns failure if 140 /// the verification failed, success otherwise. This hook may optionally be 141 /// invoked from any operation containing a region. 142 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned, 143 unsigned, NamedAttribute) { 144 return success(); 145 } 146 147 /// Parse an attribute registered to this dialect. 148 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const { 149 parser.emitError(parser.getNameLoc()) 150 << "dialect '" << getNamespace() 151 << "' provides no attribute parsing hook"; 152 return Attribute(); 153 } 154 155 /// Parse a type registered to this dialect. 156 Type Dialect::parseType(DialectAsmParser &parser) const { 157 // If this dialect allows unknown types, then represent this with OpaqueType. 158 if (allowsUnknownTypes()) { 159 Identifier ns = Identifier::get(getNamespace(), getContext()); 160 return OpaqueType::get(ns, parser.getFullSymbolSpec()); 161 } 162 163 parser.emitError(parser.getNameLoc()) 164 << "dialect '" << getNamespace() << "' provides no type parsing hook"; 165 return Type(); 166 } 167 168 Optional<Dialect::ParseOpHook> 169 Dialect::getParseOperationHook(StringRef opName) const { 170 return None; 171 } 172 173 llvm::unique_function<void(Operation *, OpAsmPrinter &printer)> 174 Dialect::getOperationPrinter(Operation *op) const { 175 assert(op->getDialect() == this && 176 "Dialect hook invoked on non-dialect owned operation"); 177 return nullptr; 178 } 179 180 /// Utility function that returns if the given string is a valid dialect 181 /// namespace. 182 bool Dialect::isValidNamespace(StringRef str) { 183 llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$"); 184 return dialectNameRegex.match(str); 185 } 186 187 /// Register a set of dialect interfaces with this dialect instance. 188 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) { 189 auto it = registeredInterfaces.try_emplace(interface->getID(), 190 std::move(interface)); 191 (void)it; 192 assert(it.second && "interface kind has already been registered"); 193 } 194 195 //===----------------------------------------------------------------------===// 196 // Dialect Interface 197 //===----------------------------------------------------------------------===// 198 199 DialectInterface::~DialectInterface() {} 200 201 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase( 202 MLIRContext *ctx, TypeID interfaceKind) { 203 for (auto *dialect : ctx->getLoadedDialects()) { 204 if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) { 205 interfaces.insert(interface); 206 orderedInterfaces.push_back(interface); 207 } 208 } 209 } 210 211 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {} 212 213 /// Get the interface for the dialect of given operation, or null if one 214 /// is not registered. 215 const DialectInterface * 216 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const { 217 return getInterfaceFor(op->getDialect()); 218 } 219