1 //===- Dialect.cpp - Dialect implementation -------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Dialect.h" 10 #include "mlir/IR/BuiltinDialect.h" 11 #include "mlir/IR/Diagnostics.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/DialectInterface.h" 14 #include "mlir/IR/MLIRContext.h" 15 #include "mlir/IR/Operation.h" 16 #include "llvm/ADT/MapVector.h" 17 #include "llvm/ADT/Twine.h" 18 #include "llvm/Support/Debug.h" 19 #include "llvm/Support/ManagedStatic.h" 20 #include "llvm/Support/Regex.h" 21 22 #define DEBUG_TYPE "dialect" 23 24 using namespace mlir; 25 using namespace detail; 26 27 DialectAsmParser::~DialectAsmParser() {} 28 29 //===----------------------------------------------------------------------===// 30 // DialectRegistry 31 //===----------------------------------------------------------------------===// 32 33 DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); } 34 35 void DialectRegistry::addDialectInterface( 36 StringRef dialectName, TypeID interfaceTypeID, 37 DialectInterfaceAllocatorFunction allocator) { 38 assert(allocator && "unexpected null interface allocation function"); 39 auto it = registry.find(dialectName.str()); 40 assert(it != registry.end() && 41 "adding an interface for an unregistered dialect"); 42 43 // Bail out if the interface with the given ID is already in the registry for 44 // the given dialect. We expect a small number (dozens) of interfaces so a 45 // linear search is fine here. 46 auto &ifaces = interfaces[it->second.first]; 47 for (const auto &kvp : ifaces.dialectInterfaces) { 48 if (kvp.first == interfaceTypeID) { 49 LLVM_DEBUG(llvm::dbgs() 50 << "[" DEBUG_TYPE 51 "] repeated interface registration for dialect " 52 << dialectName); 53 return; 54 } 55 } 56 57 ifaces.dialectInterfaces.emplace_back(interfaceTypeID, allocator); 58 } 59 60 void DialectRegistry::addObjectInterface( 61 StringRef dialectName, TypeID interfaceTypeID, 62 ObjectInterfaceAllocatorFunction allocator) { 63 assert(allocator && "unexpected null interface allocation function"); 64 auto it = registry.find(dialectName.str()); 65 assert(it != registry.end() && 66 "adding an interface for an op from an unregistered dialect"); 67 68 auto &ifaces = interfaces[it->second.first]; 69 for (const auto &kvp : ifaces.objectInterfaces) { 70 if (kvp.first == interfaceTypeID) { 71 LLVM_DEBUG(llvm::dbgs() 72 << "[" DEBUG_TYPE 73 "] repeated interface object interface registration"); 74 return; 75 } 76 } 77 78 ifaces.objectInterfaces.emplace_back(interfaceTypeID, allocator); 79 } 80 81 DialectAllocatorFunctionRef 82 DialectRegistry::getDialectAllocator(StringRef name) const { 83 auto it = registry.find(name.str()); 84 if (it == registry.end()) 85 return nullptr; 86 return it->second.second; 87 } 88 89 void DialectRegistry::insert(TypeID typeID, StringRef name, 90 DialectAllocatorFunction ctor) { 91 auto inserted = registry.insert( 92 std::make_pair(std::string(name), std::make_pair(typeID, ctor))); 93 if (!inserted.second && inserted.first->second.first != typeID) { 94 llvm::report_fatal_error( 95 "Trying to register different dialects for the same namespace: " + 96 name); 97 } 98 } 99 100 void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const { 101 auto it = interfaces.find(dialect->getTypeID()); 102 if (it == interfaces.end()) 103 return; 104 105 // Add an interface if it is not already present. 106 for (const auto &kvp : it->getSecond().dialectInterfaces) { 107 if (dialect->getRegisteredInterface(kvp.first)) 108 continue; 109 dialect->addInterface(kvp.second(dialect)); 110 } 111 112 // Add attribute, operation and type interfaces. 113 for (const auto &kvp : it->getSecond().objectInterfaces) 114 kvp.second(dialect->getContext()); 115 } 116 117 //===----------------------------------------------------------------------===// 118 // Dialect 119 //===----------------------------------------------------------------------===// 120 121 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id) 122 : name(name), dialectID(id), context(context) { 123 assert(isValidNamespace(name) && "invalid dialect namespace"); 124 } 125 126 Dialect::~Dialect() {} 127 128 /// Verify an attribute from this dialect on the argument at 'argIndex' for 129 /// the region at 'regionIndex' on the given operation. Returns failure if 130 /// the verification failed, success otherwise. This hook may optionally be 131 /// invoked from any operation containing a region. 132 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned, 133 NamedAttribute) { 134 return success(); 135 } 136 137 /// Verify an attribute from this dialect on the result at 'resultIndex' for 138 /// the region at 'regionIndex' on the given operation. Returns failure if 139 /// the verification failed, success otherwise. This hook may optionally be 140 /// invoked from any operation containing a region. 141 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned, 142 unsigned, NamedAttribute) { 143 return success(); 144 } 145 146 /// Parse an attribute registered to this dialect. 147 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const { 148 parser.emitError(parser.getNameLoc()) 149 << "dialect '" << getNamespace() 150 << "' provides no attribute parsing hook"; 151 return Attribute(); 152 } 153 154 /// Parse a type registered to this dialect. 155 Type Dialect::parseType(DialectAsmParser &parser) const { 156 // If this dialect allows unknown types, then represent this with OpaqueType. 157 if (allowsUnknownTypes()) { 158 Identifier ns = Identifier::get(getNamespace(), getContext()); 159 return OpaqueType::get(ns, parser.getFullSymbolSpec()); 160 } 161 162 parser.emitError(parser.getNameLoc()) 163 << "dialect '" << getNamespace() << "' provides no type parsing hook"; 164 return Type(); 165 } 166 167 Optional<Dialect::ParseOpHook> 168 Dialect::getParseOperationHook(StringRef opName) const { 169 return None; 170 } 171 172 LogicalResult Dialect::printOperation(Operation *op, 173 OpAsmPrinter &printer) const { 174 assert(op->getDialect() == this && 175 "Dialect hook invoked on non-dialect owned operation"); 176 return failure(); 177 } 178 179 /// Utility function that returns if the given string is a valid dialect 180 /// namespace. 181 bool Dialect::isValidNamespace(StringRef str) { 182 llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$"); 183 return dialectNameRegex.match(str); 184 } 185 186 /// Register a set of dialect interfaces with this dialect instance. 187 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) { 188 auto it = registeredInterfaces.try_emplace(interface->getID(), 189 std::move(interface)); 190 (void)it; 191 assert(it.second && "interface kind has already been registered"); 192 } 193 194 //===----------------------------------------------------------------------===// 195 // Dialect Interface 196 //===----------------------------------------------------------------------===// 197 198 DialectInterface::~DialectInterface() {} 199 200 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase( 201 MLIRContext *ctx, TypeID interfaceKind) { 202 for (auto *dialect : ctx->getLoadedDialects()) { 203 if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) { 204 interfaces.insert(interface); 205 orderedInterfaces.push_back(interface); 206 } 207 } 208 } 209 210 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {} 211 212 /// Get the interface for the dialect of given operation, or null if one 213 /// is not registered. 214 const DialectInterface * 215 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const { 216 return getInterfaceFor(op->getDialect()); 217 } 218