1 //===- Dialect.cpp - Dialect implementation -------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Dialect.h" 10 #include "mlir/IR/BuiltinDialect.h" 11 #include "mlir/IR/Diagnostics.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/DialectInterface.h" 14 #include "mlir/IR/MLIRContext.h" 15 #include "mlir/IR/Operation.h" 16 #include "llvm/ADT/MapVector.h" 17 #include "llvm/ADT/Twine.h" 18 #include "llvm/Support/Debug.h" 19 #include "llvm/Support/ManagedStatic.h" 20 #include "llvm/Support/Regex.h" 21 22 #define DEBUG_TYPE "dialect" 23 24 using namespace mlir; 25 using namespace detail; 26 27 DialectAsmParser::~DialectAsmParser() {} 28 29 //===----------------------------------------------------------------------===// 30 // DialectRegistry 31 //===----------------------------------------------------------------------===// 32 33 DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); } 34 35 void DialectRegistry::addDialectInterface( 36 StringRef dialectName, TypeID interfaceTypeID, 37 DialectInterfaceAllocatorFunction allocator) { 38 assert(allocator && "unexpected null interface allocation function"); 39 auto it = registry.find(dialectName.str()); 40 assert(it != registry.end() && 41 "adding an interface for an unregistered dialect"); 42 43 // Bail out if the interface with the given ID is already in the registry for 44 // the given dialect. We expect a small number (dozens) of interfaces so a 45 // linear search is fine here. 46 auto &ifaces = interfaces[it->second.first]; 47 for (const auto &kvp : ifaces.dialectInterfaces) { 48 if (kvp.first == interfaceTypeID) { 49 LLVM_DEBUG(llvm::dbgs() 50 << "[" DEBUG_TYPE 51 "] repeated interface registration for dialect " 52 << dialectName); 53 return; 54 } 55 } 56 57 ifaces.dialectInterfaces.emplace_back(interfaceTypeID, allocator); 58 } 59 60 void DialectRegistry::addObjectInterface( 61 StringRef dialectName, TypeID objectID, TypeID interfaceTypeID, 62 ObjectInterfaceAllocatorFunction allocator) { 63 assert(allocator && "unexpected null interface allocation function"); 64 65 auto it = registry.find(dialectName.str()); 66 assert(it != registry.end() && 67 "adding an interface for an op from an unregistered dialect"); 68 69 auto dialectID = it->second.first; 70 auto &ifaces = interfaces[dialectID]; 71 72 for (const auto &info : ifaces.objectInterfaces) { 73 if (std::get<0>(info) == objectID && std::get<1>(info) == interfaceTypeID) { 74 LLVM_DEBUG(llvm::dbgs() 75 << "[" DEBUG_TYPE 76 "] repeated interface object interface registration"); 77 return; 78 } 79 } 80 81 ifaces.objectInterfaces.emplace_back(objectID, interfaceTypeID, allocator); 82 } 83 84 DialectAllocatorFunctionRef 85 DialectRegistry::getDialectAllocator(StringRef name) const { 86 auto it = registry.find(name.str()); 87 if (it == registry.end()) 88 return nullptr; 89 return it->second.second; 90 } 91 92 void DialectRegistry::insert(TypeID typeID, StringRef name, 93 DialectAllocatorFunction ctor) { 94 auto inserted = registry.insert( 95 std::make_pair(std::string(name), std::make_pair(typeID, ctor))); 96 if (!inserted.second && inserted.first->second.first != typeID) { 97 llvm::report_fatal_error( 98 "Trying to register different dialects for the same namespace: " + 99 name); 100 } 101 } 102 103 void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const { 104 auto it = interfaces.find(dialect->getTypeID()); 105 if (it == interfaces.end()) 106 return; 107 108 // Add an interface if it is not already present. 109 for (const auto &kvp : it->getSecond().dialectInterfaces) { 110 if (dialect->getRegisteredInterface(kvp.first)) 111 continue; 112 dialect->addInterface(kvp.second(dialect)); 113 } 114 115 // Add attribute, operation and type interfaces. 116 for (const auto &info : it->getSecond().objectInterfaces) 117 std::get<2>(info)(dialect->getContext()); 118 } 119 120 //===----------------------------------------------------------------------===// 121 // Dialect 122 //===----------------------------------------------------------------------===// 123 124 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id) 125 : name(name), dialectID(id), context(context) { 126 assert(isValidNamespace(name) && "invalid dialect namespace"); 127 } 128 129 Dialect::~Dialect() {} 130 131 /// Verify an attribute from this dialect on the argument at 'argIndex' for 132 /// the region at 'regionIndex' on the given operation. Returns failure if 133 /// the verification failed, success otherwise. This hook may optionally be 134 /// invoked from any operation containing a region. 135 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned, 136 NamedAttribute) { 137 return success(); 138 } 139 140 /// Verify an attribute from this dialect on the result at 'resultIndex' for 141 /// the region at 'regionIndex' on the given operation. Returns failure if 142 /// the verification failed, success otherwise. This hook may optionally be 143 /// invoked from any operation containing a region. 144 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned, 145 unsigned, NamedAttribute) { 146 return success(); 147 } 148 149 /// Parse an attribute registered to this dialect. 150 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const { 151 parser.emitError(parser.getNameLoc()) 152 << "dialect '" << getNamespace() 153 << "' provides no attribute parsing hook"; 154 return Attribute(); 155 } 156 157 /// Parse a type registered to this dialect. 158 Type Dialect::parseType(DialectAsmParser &parser) const { 159 // If this dialect allows unknown types, then represent this with OpaqueType. 160 if (allowsUnknownTypes()) { 161 Identifier ns = Identifier::get(getNamespace(), getContext()); 162 return OpaqueType::get(ns, parser.getFullSymbolSpec()); 163 } 164 165 parser.emitError(parser.getNameLoc()) 166 << "dialect '" << getNamespace() << "' provides no type parsing hook"; 167 return Type(); 168 } 169 170 Optional<Dialect::ParseOpHook> 171 Dialect::getParseOperationHook(StringRef opName) const { 172 return None; 173 } 174 175 LogicalResult Dialect::printOperation(Operation *op, 176 OpAsmPrinter &printer) const { 177 assert(op->getDialect() == this && 178 "Dialect hook invoked on non-dialect owned operation"); 179 return failure(); 180 } 181 182 /// Utility function that returns if the given string is a valid dialect 183 /// namespace. 184 bool Dialect::isValidNamespace(StringRef str) { 185 llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$"); 186 return dialectNameRegex.match(str); 187 } 188 189 /// Register a set of dialect interfaces with this dialect instance. 190 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) { 191 auto it = registeredInterfaces.try_emplace(interface->getID(), 192 std::move(interface)); 193 (void)it; 194 assert(it.second && "interface kind has already been registered"); 195 } 196 197 //===----------------------------------------------------------------------===// 198 // Dialect Interface 199 //===----------------------------------------------------------------------===// 200 201 DialectInterface::~DialectInterface() {} 202 203 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase( 204 MLIRContext *ctx, TypeID interfaceKind) { 205 for (auto *dialect : ctx->getLoadedDialects()) { 206 if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) { 207 interfaces.insert(interface); 208 orderedInterfaces.push_back(interface); 209 } 210 } 211 } 212 213 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {} 214 215 /// Get the interface for the dialect of given operation, or null if one 216 /// is not registered. 217 const DialectInterface * 218 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const { 219 return getInterfaceFor(op->getDialect()); 220 } 221