xref: /llvm-project/mlir/lib/IR/Dialect.cpp (revision 9b50844fd798b5a81afd4aeb44b053d622747a42)
1 //===- Dialect.cpp - Dialect implementation -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/BuiltinDialect.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/DialectInterface.h"
14 #include "mlir/IR/MLIRContext.h"
15 #include "mlir/IR/Operation.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/Twine.h"
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/ManagedStatic.h"
20 #include "llvm/Support/Regex.h"
21 
22 #define DEBUG_TYPE "dialect"
23 
24 using namespace mlir;
25 using namespace detail;
26 
27 DialectAsmParser::~DialectAsmParser() {}
28 
29 //===----------------------------------------------------------------------===//
30 // DialectRegistry
31 //===----------------------------------------------------------------------===//
32 
33 DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
34 
35 void DialectRegistry::addDialectInterface(
36     StringRef dialectName, TypeID interfaceTypeID,
37     DialectInterfaceAllocatorFunction allocator) {
38   assert(allocator && "unexpected null interface allocation function");
39   auto it = registry.find(dialectName.str());
40   assert(it != registry.end() &&
41          "adding an interface for an unregistered dialect");
42 
43   // Bail out if the interface with the given ID is already in the registry for
44   // the given dialect. We expect a small number (dozens) of interfaces so a
45   // linear search is fine here.
46   auto &ifaces = interfaces[it->second.first];
47   for (const auto &kvp : ifaces.dialectInterfaces) {
48     if (kvp.first == interfaceTypeID) {
49       LLVM_DEBUG(llvm::dbgs()
50                  << "[" DEBUG_TYPE
51                     "] repeated interface registration for dialect "
52                  << dialectName);
53       return;
54     }
55   }
56 
57   ifaces.dialectInterfaces.emplace_back(interfaceTypeID, allocator);
58 }
59 
60 void DialectRegistry::addObjectInterface(
61     StringRef dialectName, TypeID objectID, TypeID interfaceTypeID,
62     ObjectInterfaceAllocatorFunction allocator) {
63   assert(allocator && "unexpected null interface allocation function");
64 
65   auto it = registry.find(dialectName.str());
66   assert(it != registry.end() &&
67          "adding an interface for an op from an unregistered dialect");
68 
69   auto dialectID = it->second.first;
70   auto &ifaces = interfaces[dialectID];
71 
72   for (const auto &info : ifaces.objectInterfaces) {
73     if (std::get<0>(info) == objectID && std::get<1>(info) == interfaceTypeID) {
74       LLVM_DEBUG(llvm::dbgs()
75                  << "[" DEBUG_TYPE
76                     "] repeated interface object interface registration");
77       return;
78     }
79   }
80 
81   ifaces.objectInterfaces.emplace_back(objectID, interfaceTypeID, allocator);
82 }
83 
84 DialectAllocatorFunctionRef
85 DialectRegistry::getDialectAllocator(StringRef name) const {
86   auto it = registry.find(name.str());
87   if (it == registry.end())
88     return nullptr;
89   return it->second.second;
90 }
91 
92 void DialectRegistry::insert(TypeID typeID, StringRef name,
93                              DialectAllocatorFunction ctor) {
94   auto inserted = registry.insert(
95       std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
96   if (!inserted.second && inserted.first->second.first != typeID) {
97     llvm::report_fatal_error(
98         "Trying to register different dialects for the same namespace: " +
99         name);
100   }
101 }
102 
103 void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const {
104   auto it = interfaces.find(dialect->getTypeID());
105   if (it == interfaces.end())
106     return;
107 
108   // Add an interface if it is not already present.
109   for (const auto &kvp : it->getSecond().dialectInterfaces) {
110     if (dialect->getRegisteredInterface(kvp.first))
111       continue;
112     dialect->addInterface(kvp.second(dialect));
113   }
114 
115   // Add attribute, operation and type interfaces.
116   for (const auto &info : it->getSecond().objectInterfaces)
117     std::get<2>(info)(dialect->getContext());
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // Dialect
122 //===----------------------------------------------------------------------===//
123 
124 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
125     : name(name), dialectID(id), context(context) {
126   assert(isValidNamespace(name) && "invalid dialect namespace");
127 }
128 
129 Dialect::~Dialect() {}
130 
131 /// Verify an attribute from this dialect on the argument at 'argIndex' for
132 /// the region at 'regionIndex' on the given operation. Returns failure if
133 /// the verification failed, success otherwise. This hook may optionally be
134 /// invoked from any operation containing a region.
135 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
136                                                 NamedAttribute) {
137   return success();
138 }
139 
140 /// Verify an attribute from this dialect on the result at 'resultIndex' for
141 /// the region at 'regionIndex' on the given operation. Returns failure if
142 /// the verification failed, success otherwise. This hook may optionally be
143 /// invoked from any operation containing a region.
144 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
145                                                    unsigned, NamedAttribute) {
146   return success();
147 }
148 
149 /// Parse an attribute registered to this dialect.
150 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
151   parser.emitError(parser.getNameLoc())
152       << "dialect '" << getNamespace()
153       << "' provides no attribute parsing hook";
154   return Attribute();
155 }
156 
157 /// Parse a type registered to this dialect.
158 Type Dialect::parseType(DialectAsmParser &parser) const {
159   // If this dialect allows unknown types, then represent this with OpaqueType.
160   if (allowsUnknownTypes()) {
161     Identifier ns = Identifier::get(getNamespace(), getContext());
162     return OpaqueType::get(ns, parser.getFullSymbolSpec());
163   }
164 
165   parser.emitError(parser.getNameLoc())
166       << "dialect '" << getNamespace() << "' provides no type parsing hook";
167   return Type();
168 }
169 
170 Optional<Dialect::ParseOpHook>
171 Dialect::getParseOperationHook(StringRef opName) const {
172   return None;
173 }
174 
175 LogicalResult Dialect::printOperation(Operation *op,
176                                       OpAsmPrinter &printer) const {
177   assert(op->getDialect() == this &&
178          "Dialect hook invoked on non-dialect owned operation");
179   return failure();
180 }
181 
182 /// Utility function that returns if the given string is a valid dialect
183 /// namespace.
184 bool Dialect::isValidNamespace(StringRef str) {
185   llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
186   return dialectNameRegex.match(str);
187 }
188 
189 /// Register a set of dialect interfaces with this dialect instance.
190 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
191   auto it = registeredInterfaces.try_emplace(interface->getID(),
192                                              std::move(interface));
193   (void)it;
194   assert(it.second && "interface kind has already been registered");
195 }
196 
197 //===----------------------------------------------------------------------===//
198 // Dialect Interface
199 //===----------------------------------------------------------------------===//
200 
201 DialectInterface::~DialectInterface() {}
202 
203 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
204     MLIRContext *ctx, TypeID interfaceKind) {
205   for (auto *dialect : ctx->getLoadedDialects()) {
206     if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
207       interfaces.insert(interface);
208       orderedInterfaces.push_back(interface);
209     }
210   }
211 }
212 
213 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {}
214 
215 /// Get the interface for the dialect of given operation, or null if one
216 /// is not registered.
217 const DialectInterface *
218 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
219   return getInterfaceFor(op->getDialect());
220 }
221