1 //===- Dialect.cpp - Dialect implementation -------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Dialect.h" 10 #include "mlir/IR/Diagnostics.h" 11 #include "mlir/IR/DialectImplementation.h" 12 #include "mlir/IR/DialectInterface.h" 13 #include "mlir/IR/MLIRContext.h" 14 #include "mlir/IR/Operation.h" 15 #include "llvm/ADT/MapVector.h" 16 #include "llvm/ADT/Twine.h" 17 #include "llvm/Support/Debug.h" 18 #include "llvm/Support/ManagedStatic.h" 19 #include "llvm/Support/Regex.h" 20 21 #define DEBUG_TYPE "dialect" 22 23 using namespace mlir; 24 using namespace detail; 25 26 DialectAsmParser::~DialectAsmParser() {} 27 28 //===----------------------------------------------------------------------===// 29 // DialectRegistry 30 //===----------------------------------------------------------------------===// 31 32 void DialectRegistry::addDialectInterface( 33 StringRef dialectName, TypeID interfaceTypeID, 34 InterfaceAllocatorFunction allocator) { 35 assert(allocator && "unexpected null interface allocation function"); 36 auto it = registry.find(dialectName.str()); 37 assert(it != registry.end() && 38 "adding an interface for an unregistered dialect"); 39 40 // Bail out if the interface with the given ID is already in the registry for 41 // the given dialect. We expect a small number (dozens) of interfaces so a 42 // linear search is fine here. 43 auto &dialectInterfaces = interfaces[it->second.first]; 44 for (const auto &kvp : dialectInterfaces) { 45 if (kvp.first == interfaceTypeID) { 46 LLVM_DEBUG(llvm::dbgs() 47 << "[" DEBUG_TYPE 48 "] repeated interface registration for dialect " 49 << dialectName); 50 return; 51 } 52 } 53 54 dialectInterfaces.emplace_back(interfaceTypeID, allocator); 55 } 56 57 DialectAllocatorFunctionRef 58 DialectRegistry::getDialectAllocator(StringRef name) const { 59 auto it = registry.find(name.str()); 60 if (it == registry.end()) 61 return nullptr; 62 return it->second.second; 63 } 64 65 void DialectRegistry::insert(TypeID typeID, StringRef name, 66 DialectAllocatorFunction ctor) { 67 auto inserted = registry.insert( 68 std::make_pair(std::string(name), std::make_pair(typeID, ctor))); 69 if (!inserted.second && inserted.first->second.first != typeID) { 70 llvm::report_fatal_error( 71 "Trying to register different dialects for the same namespace: " + 72 name); 73 } 74 } 75 76 void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const { 77 auto it = interfaces.find(dialect->getTypeID()); 78 if (it == interfaces.end()) 79 return; 80 81 // Add an interface if it is not already present. 82 for (const auto &kvp : it->second) { 83 if (dialect->getRegisteredInterface(kvp.first)) 84 continue; 85 dialect->addInterface(kvp.second(dialect)); 86 } 87 } 88 89 //===----------------------------------------------------------------------===// 90 // Dialect 91 //===----------------------------------------------------------------------===// 92 93 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id) 94 : name(name), dialectID(id), context(context) { 95 assert(isValidNamespace(name) && "invalid dialect namespace"); 96 } 97 98 Dialect::~Dialect() {} 99 100 /// Verify an attribute from this dialect on the argument at 'argIndex' for 101 /// the region at 'regionIndex' on the given operation. Returns failure if 102 /// the verification failed, success otherwise. This hook may optionally be 103 /// invoked from any operation containing a region. 104 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned, 105 NamedAttribute) { 106 return success(); 107 } 108 109 /// Verify an attribute from this dialect on the result at 'resultIndex' for 110 /// the region at 'regionIndex' on the given operation. Returns failure if 111 /// the verification failed, success otherwise. This hook may optionally be 112 /// invoked from any operation containing a region. 113 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned, 114 unsigned, NamedAttribute) { 115 return success(); 116 } 117 118 /// Parse an attribute registered to this dialect. 119 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const { 120 parser.emitError(parser.getNameLoc()) 121 << "dialect '" << getNamespace() 122 << "' provides no attribute parsing hook"; 123 return Attribute(); 124 } 125 126 /// Parse a type registered to this dialect. 127 Type Dialect::parseType(DialectAsmParser &parser) const { 128 // If this dialect allows unknown types, then represent this with OpaqueType. 129 if (allowsUnknownTypes()) { 130 Identifier ns = Identifier::get(getNamespace(), getContext()); 131 return OpaqueType::get(ns, parser.getFullSymbolSpec()); 132 } 133 134 parser.emitError(parser.getNameLoc()) 135 << "dialect '" << getNamespace() << "' provides no type parsing hook"; 136 return Type(); 137 } 138 139 Optional<Dialect::ParseOpHook> 140 Dialect::getParseOperationHook(StringRef opName) const { 141 return None; 142 } 143 144 LogicalResult Dialect::printOperation(Operation *op, 145 OpAsmPrinter &printer) const { 146 assert(op->getDialect() == this && 147 "Dialect hook invoked on non-dialect owned operation"); 148 return failure(); 149 } 150 151 /// Utility function that returns if the given string is a valid dialect 152 /// namespace. 153 bool Dialect::isValidNamespace(StringRef str) { 154 if (str.empty()) 155 return true; 156 llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$"); 157 return dialectNameRegex.match(str); 158 } 159 160 /// Register a set of dialect interfaces with this dialect instance. 161 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) { 162 auto it = registeredInterfaces.try_emplace(interface->getID(), 163 std::move(interface)); 164 (void)it; 165 assert(it.second && "interface kind has already been registered"); 166 } 167 168 //===----------------------------------------------------------------------===// 169 // Dialect Interface 170 //===----------------------------------------------------------------------===// 171 172 DialectInterface::~DialectInterface() {} 173 174 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase( 175 MLIRContext *ctx, TypeID interfaceKind) { 176 for (auto *dialect : ctx->getLoadedDialects()) { 177 if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) { 178 interfaces.insert(interface); 179 orderedInterfaces.push_back(interface); 180 } 181 } 182 } 183 184 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {} 185 186 /// Get the interface for the dialect of given operation, or null if one 187 /// is not registered. 188 const DialectInterface * 189 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const { 190 return getInterfaceFor(op->getDialect()); 191 } 192