1 //===- Dialect.cpp - Dialect implementation -------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Dialect.h" 10 #include "mlir/IR/BuiltinDialect.h" 11 #include "mlir/IR/Diagnostics.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/DialectInterface.h" 14 #include "mlir/IR/MLIRContext.h" 15 #include "mlir/IR/Operation.h" 16 #include "llvm/ADT/MapVector.h" 17 #include "llvm/ADT/Twine.h" 18 #include "llvm/Support/Debug.h" 19 #include "llvm/Support/ManagedStatic.h" 20 #include "llvm/Support/Regex.h" 21 22 #define DEBUG_TYPE "dialect" 23 24 using namespace mlir; 25 using namespace detail; 26 27 DialectAsmParser::~DialectAsmParser() {} 28 29 //===----------------------------------------------------------------------===// 30 // DialectRegistry 31 //===----------------------------------------------------------------------===// 32 33 void DialectRegistry::addDialectInterface( 34 StringRef dialectName, TypeID interfaceTypeID, 35 DialectInterfaceAllocatorFunction allocator) { 36 assert(allocator && "unexpected null interface allocation function"); 37 auto it = registry.find(dialectName.str()); 38 assert(it != registry.end() && 39 "adding an interface for an unregistered dialect"); 40 41 // Bail out if the interface with the given ID is already in the registry for 42 // the given dialect. We expect a small number (dozens) of interfaces so a 43 // linear search is fine here. 44 auto &ifaces = interfaces[it->second.first]; 45 for (const auto &kvp : ifaces.dialectInterfaces) { 46 if (kvp.first == interfaceTypeID) { 47 LLVM_DEBUG(llvm::dbgs() 48 << "[" DEBUG_TYPE 49 "] repeated interface registration for dialect " 50 << dialectName); 51 return; 52 } 53 } 54 55 ifaces.dialectInterfaces.emplace_back(interfaceTypeID, allocator); 56 } 57 58 void DialectRegistry::addObjectInterface( 59 StringRef dialectName, TypeID interfaceTypeID, 60 ObjectInterfaceAllocatorFunction allocator) { 61 assert(allocator && "unexpected null interface allocation function"); 62 63 // Builtin dialect has an empty prefix and is always registered. 64 TypeID dialectTypeID; 65 if (!dialectName.empty()) { 66 auto it = registry.find(dialectName.str()); 67 assert(it != registry.end() && 68 "adding an interface for an op from an unregistered dialect"); 69 dialectTypeID = it->second.first; 70 } else { 71 dialectTypeID = TypeID::get<BuiltinDialect>(); 72 } 73 74 auto &ifaces = interfaces[dialectTypeID]; 75 for (const auto &kvp : ifaces.objectInterfaces) { 76 if (kvp.first == interfaceTypeID) { 77 LLVM_DEBUG(llvm::dbgs() 78 << "[" DEBUG_TYPE 79 "] repeated interface object interface registration"); 80 return; 81 } 82 } 83 84 ifaces.objectInterfaces.emplace_back(interfaceTypeID, allocator); 85 } 86 87 DialectAllocatorFunctionRef 88 DialectRegistry::getDialectAllocator(StringRef name) const { 89 auto it = registry.find(name.str()); 90 if (it == registry.end()) 91 return nullptr; 92 return it->second.second; 93 } 94 95 void DialectRegistry::insert(TypeID typeID, StringRef name, 96 DialectAllocatorFunction ctor) { 97 auto inserted = registry.insert( 98 std::make_pair(std::string(name), std::make_pair(typeID, ctor))); 99 if (!inserted.second && inserted.first->second.first != typeID) { 100 llvm::report_fatal_error( 101 "Trying to register different dialects for the same namespace: " + 102 name); 103 } 104 } 105 106 void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const { 107 auto it = interfaces.find(dialect->getTypeID()); 108 if (it == interfaces.end()) 109 return; 110 111 // Add an interface if it is not already present. 112 for (const auto &kvp : it->getSecond().dialectInterfaces) { 113 if (dialect->getRegisteredInterface(kvp.first)) 114 continue; 115 dialect->addInterface(kvp.second(dialect)); 116 } 117 118 // Add attribute, operation and type interfaces. 119 for (const auto &kvp : it->getSecond().objectInterfaces) 120 kvp.second(dialect->getContext()); 121 } 122 123 //===----------------------------------------------------------------------===// 124 // Dialect 125 //===----------------------------------------------------------------------===// 126 127 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id) 128 : name(name), dialectID(id), context(context) { 129 assert(isValidNamespace(name) && "invalid dialect namespace"); 130 } 131 132 Dialect::~Dialect() {} 133 134 /// Verify an attribute from this dialect on the argument at 'argIndex' for 135 /// the region at 'regionIndex' on the given operation. Returns failure if 136 /// the verification failed, success otherwise. This hook may optionally be 137 /// invoked from any operation containing a region. 138 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned, 139 NamedAttribute) { 140 return success(); 141 } 142 143 /// Verify an attribute from this dialect on the result at 'resultIndex' for 144 /// the region at 'regionIndex' on the given operation. Returns failure if 145 /// the verification failed, success otherwise. This hook may optionally be 146 /// invoked from any operation containing a region. 147 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned, 148 unsigned, NamedAttribute) { 149 return success(); 150 } 151 152 /// Parse an attribute registered to this dialect. 153 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const { 154 parser.emitError(parser.getNameLoc()) 155 << "dialect '" << getNamespace() 156 << "' provides no attribute parsing hook"; 157 return Attribute(); 158 } 159 160 /// Parse a type registered to this dialect. 161 Type Dialect::parseType(DialectAsmParser &parser) const { 162 // If this dialect allows unknown types, then represent this with OpaqueType. 163 if (allowsUnknownTypes()) { 164 Identifier ns = Identifier::get(getNamespace(), getContext()); 165 return OpaqueType::get(ns, parser.getFullSymbolSpec()); 166 } 167 168 parser.emitError(parser.getNameLoc()) 169 << "dialect '" << getNamespace() << "' provides no type parsing hook"; 170 return Type(); 171 } 172 173 Optional<Dialect::ParseOpHook> 174 Dialect::getParseOperationHook(StringRef opName) const { 175 return None; 176 } 177 178 LogicalResult Dialect::printOperation(Operation *op, 179 OpAsmPrinter &printer) const { 180 assert(op->getDialect() == this && 181 "Dialect hook invoked on non-dialect owned operation"); 182 return failure(); 183 } 184 185 /// Utility function that returns if the given string is a valid dialect 186 /// namespace. 187 bool Dialect::isValidNamespace(StringRef str) { 188 if (str.empty()) 189 return true; 190 llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$"); 191 return dialectNameRegex.match(str); 192 } 193 194 /// Register a set of dialect interfaces with this dialect instance. 195 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) { 196 auto it = registeredInterfaces.try_emplace(interface->getID(), 197 std::move(interface)); 198 (void)it; 199 assert(it.second && "interface kind has already been registered"); 200 } 201 202 //===----------------------------------------------------------------------===// 203 // Dialect Interface 204 //===----------------------------------------------------------------------===// 205 206 DialectInterface::~DialectInterface() {} 207 208 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase( 209 MLIRContext *ctx, TypeID interfaceKind) { 210 for (auto *dialect : ctx->getLoadedDialects()) { 211 if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) { 212 interfaces.insert(interface); 213 orderedInterfaces.push_back(interface); 214 } 215 } 216 } 217 218 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {} 219 220 /// Get the interface for the dialect of given operation, or null if one 221 /// is not registered. 222 const DialectInterface * 223 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const { 224 return getInterfaceFor(op->getDialect()); 225 } 226