xref: /llvm-project/mlir/lib/IR/Dialect.cpp (revision 531206310a27477f088f672f5e6fd688d77d9292)
1 //===- Dialect.cpp - Dialect implementation -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/BuiltinDialect.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/DialectInterface.h"
14 #include "mlir/IR/MLIRContext.h"
15 #include "mlir/IR/Operation.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/Twine.h"
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/ManagedStatic.h"
20 #include "llvm/Support/Regex.h"
21 
22 #define DEBUG_TYPE "dialect"
23 
24 using namespace mlir;
25 using namespace detail;
26 
27 //===----------------------------------------------------------------------===//
28 // DialectRegistry
29 //===----------------------------------------------------------------------===//
30 
31 DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
32 
33 void DialectRegistry::addDialectInterface(
34     StringRef dialectName, TypeID interfaceTypeID,
35     DialectInterfaceAllocatorFunction allocator) {
36   assert(allocator && "unexpected null interface allocation function");
37   auto it = registry.find(dialectName.str());
38   assert(it != registry.end() &&
39          "adding an interface for an unregistered dialect");
40 
41   // Bail out if the interface with the given ID is already in the registry for
42   // the given dialect. We expect a small number (dozens) of interfaces so a
43   // linear search is fine here.
44   auto &ifaces = interfaces[it->second.first];
45   for (const auto &kvp : ifaces.dialectInterfaces) {
46     if (kvp.first == interfaceTypeID) {
47       LLVM_DEBUG(llvm::dbgs()
48                  << "[" DEBUG_TYPE
49                     "] repeated interface registration for dialect "
50                  << dialectName);
51       return;
52     }
53   }
54 
55   ifaces.dialectInterfaces.emplace_back(interfaceTypeID, allocator);
56 }
57 
58 void DialectRegistry::addObjectInterface(
59     StringRef dialectName, TypeID objectID, TypeID interfaceTypeID,
60     ObjectInterfaceAllocatorFunction allocator) {
61   assert(allocator && "unexpected null interface allocation function");
62 
63   auto it = registry.find(dialectName.str());
64   assert(it != registry.end() &&
65          "adding an interface for an op from an unregistered dialect");
66 
67   auto dialectID = it->second.first;
68   auto &ifaces = interfaces[dialectID];
69 
70   for (const auto &info : ifaces.objectInterfaces) {
71     if (std::get<0>(info) == objectID && std::get<1>(info) == interfaceTypeID) {
72       LLVM_DEBUG(llvm::dbgs()
73                  << "[" DEBUG_TYPE
74                     "] repeated interface object interface registration");
75       return;
76     }
77   }
78 
79   ifaces.objectInterfaces.emplace_back(objectID, interfaceTypeID, allocator);
80 }
81 
82 DialectAllocatorFunctionRef
83 DialectRegistry::getDialectAllocator(StringRef name) const {
84   auto it = registry.find(name.str());
85   if (it == registry.end())
86     return nullptr;
87   return it->second.second;
88 }
89 
90 void DialectRegistry::insert(TypeID typeID, StringRef name,
91                              DialectAllocatorFunction ctor) {
92   auto inserted = registry.insert(
93       std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
94   if (!inserted.second && inserted.first->second.first != typeID) {
95     llvm::report_fatal_error(
96         "Trying to register different dialects for the same namespace: " +
97         name);
98   }
99 }
100 
101 void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const {
102   auto it = interfaces.find(dialect->getTypeID());
103   if (it == interfaces.end())
104     return;
105 
106   // Add an interface if it is not already present.
107   for (const auto &kvp : it->getSecond().dialectInterfaces) {
108     if (dialect->getRegisteredInterface(kvp.first))
109       continue;
110     dialect->addInterface(kvp.second(dialect));
111   }
112 
113   // Add attribute, operation and type interfaces.
114   for (const auto &info : it->getSecond().objectInterfaces)
115     std::get<2>(info)(dialect->getContext());
116 }
117 
118 //===----------------------------------------------------------------------===//
119 // Dialect
120 //===----------------------------------------------------------------------===//
121 
122 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
123     : name(name), dialectID(id), context(context) {
124   assert(isValidNamespace(name) && "invalid dialect namespace");
125 }
126 
127 Dialect::~Dialect() {}
128 
129 /// Verify an attribute from this dialect on the argument at 'argIndex' for
130 /// the region at 'regionIndex' on the given operation. Returns failure if
131 /// the verification failed, success otherwise. This hook may optionally be
132 /// invoked from any operation containing a region.
133 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
134                                                 NamedAttribute) {
135   return success();
136 }
137 
138 /// Verify an attribute from this dialect on the result at 'resultIndex' for
139 /// the region at 'regionIndex' on the given operation. Returns failure if
140 /// the verification failed, success otherwise. This hook may optionally be
141 /// invoked from any operation containing a region.
142 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
143                                                    unsigned, NamedAttribute) {
144   return success();
145 }
146 
147 /// Parse an attribute registered to this dialect.
148 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
149   parser.emitError(parser.getNameLoc())
150       << "dialect '" << getNamespace()
151       << "' provides no attribute parsing hook";
152   return Attribute();
153 }
154 
155 /// Parse a type registered to this dialect.
156 Type Dialect::parseType(DialectAsmParser &parser) const {
157   // If this dialect allows unknown types, then represent this with OpaqueType.
158   if (allowsUnknownTypes()) {
159     Identifier ns = Identifier::get(getNamespace(), getContext());
160     return OpaqueType::get(ns, parser.getFullSymbolSpec());
161   }
162 
163   parser.emitError(parser.getNameLoc())
164       << "dialect '" << getNamespace() << "' provides no type parsing hook";
165   return Type();
166 }
167 
168 Optional<Dialect::ParseOpHook>
169 Dialect::getParseOperationHook(StringRef opName) const {
170   return None;
171 }
172 
173 llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
174 Dialect::getOperationPrinter(Operation *op) const {
175   assert(op->getDialect() == this &&
176          "Dialect hook invoked on non-dialect owned operation");
177   return nullptr;
178 }
179 
180 /// Utility function that returns if the given string is a valid dialect
181 /// namespace.
182 bool Dialect::isValidNamespace(StringRef str) {
183   llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
184   return dialectNameRegex.match(str);
185 }
186 
187 /// Register a set of dialect interfaces with this dialect instance.
188 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
189   auto it = registeredInterfaces.try_emplace(interface->getID(),
190                                              std::move(interface));
191   (void)it;
192   assert(it.second && "interface kind has already been registered");
193 }
194 
195 //===----------------------------------------------------------------------===//
196 // Dialect Interface
197 //===----------------------------------------------------------------------===//
198 
199 DialectInterface::~DialectInterface() {}
200 
201 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
202     MLIRContext *ctx, TypeID interfaceKind) {
203   for (auto *dialect : ctx->getLoadedDialects()) {
204     if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
205       interfaces.insert(interface);
206       orderedInterfaces.push_back(interface);
207     }
208   }
209 }
210 
211 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {}
212 
213 /// Get the interface for the dialect of given operation, or null if one
214 /// is not registered.
215 const DialectInterface *
216 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
217   return getInterfaceFor(op->getDialect());
218 }
219