xref: /llvm-project/mlir/lib/IR/Dialect.cpp (revision a0c776fc94d3179822c95dcb9f79b344e13f069b)
1 //===- Dialect.cpp - Dialect implementation -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/Diagnostics.h"
11 #include "mlir/IR/DialectImplementation.h"
12 #include "mlir/IR/DialectInterface.h"
13 #include "mlir/IR/MLIRContext.h"
14 #include "mlir/IR/Operation.h"
15 #include "llvm/ADT/MapVector.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/Debug.h"
18 #include "llvm/Support/ManagedStatic.h"
19 #include "llvm/Support/Regex.h"
20 
21 #define DEBUG_TYPE "dialect"
22 
23 using namespace mlir;
24 using namespace detail;
25 
26 DialectAsmParser::~DialectAsmParser() {}
27 
28 //===----------------------------------------------------------------------===//
29 // DialectRegistry
30 //===----------------------------------------------------------------------===//
31 
32 void DialectRegistry::addDialectInterface(
33     StringRef dialectName, TypeID interfaceTypeID,
34     InterfaceAllocatorFunction allocator) {
35   assert(allocator && "unexpected null interface allocation function");
36   auto it = registry.find(dialectName.str());
37   assert(it != registry.end() &&
38          "adding an interface for an unregistered dialect");
39 
40   // Bail out if the interface with the given ID is already in the registry for
41   // the given dialect. We expect a small number (dozens) of interfaces so a
42   // linear search is fine here.
43   auto &dialectInterfaces = interfaces[it->second.first];
44   for (const auto &kvp : dialectInterfaces) {
45     if (kvp.first == interfaceTypeID) {
46       LLVM_DEBUG(llvm::dbgs()
47                  << "[" DEBUG_TYPE
48                     "] repeated interface registration for dialect "
49                  << dialectName);
50       return;
51     }
52   }
53 
54   dialectInterfaces.emplace_back(interfaceTypeID, allocator);
55 }
56 
57 DialectAllocatorFunctionRef
58 DialectRegistry::getDialectAllocator(StringRef name) const {
59   auto it = registry.find(name.str());
60   if (it == registry.end())
61     return nullptr;
62   return it->second.second;
63 }
64 
65 void DialectRegistry::insert(TypeID typeID, StringRef name,
66                              DialectAllocatorFunction ctor) {
67   auto inserted = registry.insert(
68       std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
69   if (!inserted.second && inserted.first->second.first != typeID) {
70     llvm::report_fatal_error(
71         "Trying to register different dialects for the same namespace: " +
72         name);
73   }
74 }
75 
76 void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const {
77   auto it = interfaces.find(dialect->getTypeID());
78   if (it == interfaces.end())
79     return;
80 
81   // Add an interface if it is not already present.
82   for (const auto &kvp : it->second) {
83     if (dialect->getRegisteredInterface(kvp.first))
84       continue;
85     dialect->addInterface(kvp.second(dialect));
86   }
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // Dialect
91 //===----------------------------------------------------------------------===//
92 
93 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
94     : name(name), dialectID(id), context(context) {
95   assert(isValidNamespace(name) && "invalid dialect namespace");
96 }
97 
98 Dialect::~Dialect() {}
99 
100 /// Verify an attribute from this dialect on the argument at 'argIndex' for
101 /// the region at 'regionIndex' on the given operation. Returns failure if
102 /// the verification failed, success otherwise. This hook may optionally be
103 /// invoked from any operation containing a region.
104 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
105                                                 NamedAttribute) {
106   return success();
107 }
108 
109 /// Verify an attribute from this dialect on the result at 'resultIndex' for
110 /// the region at 'regionIndex' on the given operation. Returns failure if
111 /// the verification failed, success otherwise. This hook may optionally be
112 /// invoked from any operation containing a region.
113 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
114                                                    unsigned, NamedAttribute) {
115   return success();
116 }
117 
118 /// Parse an attribute registered to this dialect.
119 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
120   parser.emitError(parser.getNameLoc())
121       << "dialect '" << getNamespace()
122       << "' provides no attribute parsing hook";
123   return Attribute();
124 }
125 
126 /// Parse a type registered to this dialect.
127 Type Dialect::parseType(DialectAsmParser &parser) const {
128   // If this dialect allows unknown types, then represent this with OpaqueType.
129   if (allowsUnknownTypes()) {
130     Identifier ns = Identifier::get(getNamespace(), getContext());
131     return OpaqueType::get(ns, parser.getFullSymbolSpec());
132   }
133 
134   parser.emitError(parser.getNameLoc())
135       << "dialect '" << getNamespace() << "' provides no type parsing hook";
136   return Type();
137 }
138 
139 Optional<Dialect::ParseOpHook>
140 Dialect::getParseOperationHook(StringRef opName) const {
141   return None;
142 }
143 
144 LogicalResult Dialect::printOperation(Operation *op,
145                                       OpAsmPrinter &printer) const {
146   assert(op->getDialect() == this &&
147          "Dialect hook invoked on non-dialect owned operation");
148   return failure();
149 }
150 
151 /// Utility function that returns if the given string is a valid dialect
152 /// namespace.
153 bool Dialect::isValidNamespace(StringRef str) {
154   if (str.empty())
155     return true;
156   llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
157   return dialectNameRegex.match(str);
158 }
159 
160 /// Register a set of dialect interfaces with this dialect instance.
161 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
162   auto it = registeredInterfaces.try_emplace(interface->getID(),
163                                              std::move(interface));
164   (void)it;
165   assert(it.second && "interface kind has already been registered");
166 }
167 
168 //===----------------------------------------------------------------------===//
169 // Dialect Interface
170 //===----------------------------------------------------------------------===//
171 
172 DialectInterface::~DialectInterface() {}
173 
174 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
175     MLIRContext *ctx, TypeID interfaceKind) {
176   for (auto *dialect : ctx->getLoadedDialects()) {
177     if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
178       interfaces.insert(interface);
179       orderedInterfaces.push_back(interface);
180     }
181   }
182 }
183 
184 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {}
185 
186 /// Get the interface for the dialect of given operation, or null if one
187 /// is not registered.
188 const DialectInterface *
189 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
190   return getInterfaceFor(op->getDialect());
191 }
192