xref: /llvm-project/mlir/lib/IR/Dialect.cpp (revision 77eee5795e2cf753e4400fb089d01018417c4ee0)
1 //===- Dialect.cpp - Dialect implementation -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/BuiltinDialect.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/DialectInterface.h"
14 #include "mlir/IR/MLIRContext.h"
15 #include "mlir/IR/Operation.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/Twine.h"
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/ManagedStatic.h"
20 #include "llvm/Support/Regex.h"
21 
22 #define DEBUG_TYPE "dialect"
23 
24 using namespace mlir;
25 using namespace detail;
26 
27 //===----------------------------------------------------------------------===//
28 // Dialect
29 //===----------------------------------------------------------------------===//
30 
31 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
32     : name(name), dialectID(id), context(context) {
33   assert(isValidNamespace(name) && "invalid dialect namespace");
34 }
35 
36 Dialect::~Dialect() = default;
37 
38 /// Verify an attribute from this dialect on the argument at 'argIndex' for
39 /// the region at 'regionIndex' on the given operation. Returns failure if
40 /// the verification failed, success otherwise. This hook may optionally be
41 /// invoked from any operation containing a region.
42 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
43                                                 NamedAttribute) {
44   return success();
45 }
46 
47 /// Verify an attribute from this dialect on the result at 'resultIndex' for
48 /// the region at 'regionIndex' on the given operation. Returns failure if
49 /// the verification failed, success otherwise. This hook may optionally be
50 /// invoked from any operation containing a region.
51 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
52                                                    unsigned, NamedAttribute) {
53   return success();
54 }
55 
56 /// Parse an attribute registered to this dialect.
57 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
58   parser.emitError(parser.getNameLoc())
59       << "dialect '" << getNamespace()
60       << "' provides no attribute parsing hook";
61   return Attribute();
62 }
63 
64 /// Parse a type registered to this dialect.
65 Type Dialect::parseType(DialectAsmParser &parser) const {
66   // If this dialect allows unknown types, then represent this with OpaqueType.
67   if (allowsUnknownTypes()) {
68     StringAttr ns = StringAttr::get(getContext(), getNamespace());
69     return OpaqueType::get(ns, parser.getFullSymbolSpec());
70   }
71 
72   parser.emitError(parser.getNameLoc())
73       << "dialect '" << getNamespace() << "' provides no type parsing hook";
74   return Type();
75 }
76 
77 Optional<Dialect::ParseOpHook>
78 Dialect::getParseOperationHook(StringRef opName) const {
79   return None;
80 }
81 
82 llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
83 Dialect::getOperationPrinter(Operation *op) const {
84   assert(op->getDialect() == this &&
85          "Dialect hook invoked on non-dialect owned operation");
86   return nullptr;
87 }
88 
89 /// Utility function that returns if the given string is a valid dialect
90 /// namespace
91 bool Dialect::isValidNamespace(StringRef str) {
92   llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
93   return dialectNameRegex.match(str);
94 }
95 
96 /// Register a set of dialect interfaces with this dialect instance.
97 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
98   auto it = registeredInterfaces.try_emplace(interface->getID(),
99                                              std::move(interface));
100   (void)it;
101   LLVM_DEBUG({
102     if (!it.second) {
103       llvm::dbgs() << "[" DEBUG_TYPE
104                       "] repeated interface registration for dialect "
105                    << getNamespace();
106     }
107   });
108 }
109 
110 //===----------------------------------------------------------------------===//
111 // Dialect Interface
112 //===----------------------------------------------------------------------===//
113 
114 DialectInterface::~DialectInterface() = default;
115 
116 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
117     MLIRContext *ctx, TypeID interfaceKind) {
118   for (auto *dialect : ctx->getLoadedDialects()) {
119     if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
120       interfaces.insert(interface);
121       orderedInterfaces.push_back(interface);
122     }
123   }
124 }
125 
126 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() = default;
127 
128 /// Get the interface for the dialect of given operation, or null if one
129 /// is not registered.
130 const DialectInterface *
131 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
132   return getInterfaceFor(op->getDialect());
133 }
134 
135 //===----------------------------------------------------------------------===//
136 // DialectExtension
137 //===----------------------------------------------------------------------===//
138 
139 DialectExtensionBase::~DialectExtensionBase() = default;
140 
141 //===----------------------------------------------------------------------===//
142 // DialectRegistry
143 //===----------------------------------------------------------------------===//
144 
145 DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
146 
147 DialectAllocatorFunctionRef
148 DialectRegistry::getDialectAllocator(StringRef name) const {
149   auto it = registry.find(name.str());
150   if (it == registry.end())
151     return nullptr;
152   return it->second.second;
153 }
154 
155 void DialectRegistry::insert(TypeID typeID, StringRef name,
156                              const DialectAllocatorFunction &ctor) {
157   auto inserted = registry.insert(
158       std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
159   if (!inserted.second && inserted.first->second.first != typeID) {
160     llvm::report_fatal_error(
161         "Trying to register different dialects for the same namespace: " +
162         name);
163   }
164 }
165 
166 void DialectRegistry::applyExtensions(Dialect *dialect) const {
167   MLIRContext *ctx = dialect->getContext();
168   StringRef dialectName = dialect->getNamespace();
169 
170   // Functor used to try to apply the given extension.
171   auto applyExtension = [&](const DialectExtensionBase &extension) {
172     ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
173 
174     // Handle the simple case of a single dialect name. In this case, the
175     // required dialect should be the current dialect.
176     if (dialectNames.size() == 1) {
177       if (dialectNames.front() == dialectName)
178         extension.apply(ctx, dialect);
179       return;
180     }
181 
182     // Otherwise, check to see if this extension requires this dialect.
183     const StringRef *nameIt = llvm::find(dialectNames, dialectName);
184     if (nameIt == dialectNames.end())
185       return;
186 
187     // If it does, ensure that all of the other required dialects have been
188     // loaded.
189     SmallVector<Dialect *> requiredDialects;
190     requiredDialects.reserve(dialectNames.size());
191     for (auto it = dialectNames.begin(), e = dialectNames.end(); it != e;
192          ++it) {
193       // The current dialect is known to be loaded.
194       if (it == nameIt) {
195         requiredDialects.push_back(dialect);
196         continue;
197       }
198       // Otherwise, check if it is loaded.
199       Dialect *loadedDialect = ctx->getLoadedDialect(*it);
200       if (!loadedDialect)
201         return;
202       requiredDialects.push_back(loadedDialect);
203     }
204     extension.apply(ctx, requiredDialects);
205   };
206 
207   for (const auto &extension : extensions)
208     applyExtension(*extension);
209 }
210 
211 void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
212   // Functor used to try to apply the given extension.
213   auto applyExtension = [&](const DialectExtensionBase &extension) {
214     ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
215 
216     // Check to see if all of the dialects for this extension are loaded.
217     SmallVector<Dialect *> requiredDialects;
218     requiredDialects.reserve(dialectNames.size());
219     for (StringRef dialectName : dialectNames) {
220       Dialect *loadedDialect = ctx->getLoadedDialect(dialectName);
221       if (!loadedDialect)
222         return;
223       requiredDialects.push_back(loadedDialect);
224     }
225     extension.apply(ctx, requiredDialects);
226   };
227 
228   for (const auto &extension : extensions)
229     applyExtension(*extension);
230 }
231