xref: /llvm-project/mlir/lib/IR/ExtensibleDialect.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- ExtensibleDialect.cpp - Extensible dialect ---------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/ExtensibleDialect.h"
10 #include "mlir/IR/AttributeSupport.h"
11 #include "mlir/IR/DialectImplementation.h"
12 #include "mlir/IR/OperationSupport.h"
13 #include "mlir/IR/StorageUniquerSupport.h"
14 
15 using namespace mlir;
16 
17 //===----------------------------------------------------------------------===//
18 // Dynamic types and attributes shared functions
19 //===----------------------------------------------------------------------===//
20 
21 /// Default parser for dynamic attribute or type parameters.
22 /// Parse in the format '(<>)?' or '<attr (,attr)*>'.
23 static LogicalResult
typeOrAttrParser(AsmParser & parser,SmallVectorImpl<Attribute> & parsedParams)24 typeOrAttrParser(AsmParser &parser, SmallVectorImpl<Attribute> &parsedParams) {
25   // No parameters
26   if (parser.parseOptionalLess() || !parser.parseOptionalGreater())
27     return success();
28 
29   Attribute attr;
30   if (parser.parseAttribute(attr))
31     return failure();
32   parsedParams.push_back(attr);
33 
34   while (parser.parseOptionalGreater()) {
35     Attribute attr;
36     if (parser.parseComma() || parser.parseAttribute(attr))
37       return failure();
38     parsedParams.push_back(attr);
39   }
40 
41   return success();
42 }
43 
44 /// Default printer for dynamic attribute or type parameters.
45 /// Print in the format '(<>)?' or '<attr (,attr)*>'.
typeOrAttrPrinter(AsmPrinter & printer,ArrayRef<Attribute> params)46 static void typeOrAttrPrinter(AsmPrinter &printer, ArrayRef<Attribute> params) {
47   if (params.empty())
48     return;
49 
50   printer << "<";
51   interleaveComma(params, printer.getStream());
52   printer << ">";
53 }
54 
55 //===----------------------------------------------------------------------===//
56 // Dynamic type
57 //===----------------------------------------------------------------------===//
58 
59 std::unique_ptr<DynamicTypeDefinition>
get(StringRef name,ExtensibleDialect * dialect,VerifierFn && verifier)60 DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect,
61                            VerifierFn &&verifier) {
62   return DynamicTypeDefinition::get(name, dialect, std::move(verifier),
63                                     typeOrAttrParser, typeOrAttrPrinter);
64 }
65 
66 std::unique_ptr<DynamicTypeDefinition>
get(StringRef name,ExtensibleDialect * dialect,VerifierFn && verifier,ParserFn && parser,PrinterFn && printer)67 DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect,
68                            VerifierFn &&verifier, ParserFn &&parser,
69                            PrinterFn &&printer) {
70   return std::unique_ptr<DynamicTypeDefinition>(
71       new DynamicTypeDefinition(name, dialect, std::move(verifier),
72                                 std::move(parser), std::move(printer)));
73 }
74 
DynamicTypeDefinition(StringRef nameRef,ExtensibleDialect * dialect,VerifierFn && verifier,ParserFn && parser,PrinterFn && printer)75 DynamicTypeDefinition::DynamicTypeDefinition(StringRef nameRef,
76                                              ExtensibleDialect *dialect,
77                                              VerifierFn &&verifier,
78                                              ParserFn &&parser,
79                                              PrinterFn &&printer)
80     : name(nameRef), dialect(dialect), verifier(std::move(verifier)),
81       parser(std::move(parser)), printer(std::move(printer)),
82       ctx(dialect->getContext()) {}
83 
DynamicTypeDefinition(ExtensibleDialect * dialect,StringRef nameRef)84 DynamicTypeDefinition::DynamicTypeDefinition(ExtensibleDialect *dialect,
85                                              StringRef nameRef)
86     : name(nameRef), dialect(dialect), ctx(dialect->getContext()) {}
87 
registerInTypeUniquer()88 void DynamicTypeDefinition::registerInTypeUniquer() {
89   detail::TypeUniquer::registerType<DynamicType>(&getContext(), getTypeID());
90 }
91 
92 namespace mlir {
93 namespace detail {
94 /// Storage of DynamicType.
95 /// Contains a pointer to the type definition and type parameters.
96 struct DynamicTypeStorage : public TypeStorage {
97 
98   using KeyTy = std::pair<DynamicTypeDefinition *, ArrayRef<Attribute>>;
99 
DynamicTypeStoragemlir::detail::DynamicTypeStorage100   explicit DynamicTypeStorage(DynamicTypeDefinition *typeDef,
101                               ArrayRef<Attribute> params)
102       : typeDef(typeDef), params(params) {}
103 
operator ==mlir::detail::DynamicTypeStorage104   bool operator==(const KeyTy &key) const {
105     return typeDef == key.first && params == key.second;
106   }
107 
hashKeymlir::detail::DynamicTypeStorage108   static llvm::hash_code hashKey(const KeyTy &key) {
109     return llvm::hash_value(key);
110   }
111 
constructmlir::detail::DynamicTypeStorage112   static DynamicTypeStorage *construct(TypeStorageAllocator &alloc,
113                                        const KeyTy &key) {
114     return new (alloc.allocate<DynamicTypeStorage>())
115         DynamicTypeStorage(key.first, alloc.copyInto(key.second));
116   }
117 
118   /// Definition of the type.
119   DynamicTypeDefinition *typeDef;
120 
121   /// The type parameters.
122   ArrayRef<Attribute> params;
123 };
124 } // namespace detail
125 } // namespace mlir
126 
get(DynamicTypeDefinition * typeDef,ArrayRef<Attribute> params)127 DynamicType DynamicType::get(DynamicTypeDefinition *typeDef,
128                              ArrayRef<Attribute> params) {
129   auto &ctx = typeDef->getContext();
130   auto emitError = detail::getDefaultDiagnosticEmitFn(&ctx);
131   assert(succeeded(typeDef->verify(emitError, params)));
132   return detail::TypeUniquer::getWithTypeID<DynamicType>(
133       &ctx, typeDef->getTypeID(), typeDef, params);
134 }
135 
136 DynamicType
getChecked(function_ref<InFlightDiagnostic ()> emitError,DynamicTypeDefinition * typeDef,ArrayRef<Attribute> params)137 DynamicType::getChecked(function_ref<InFlightDiagnostic()> emitError,
138                         DynamicTypeDefinition *typeDef,
139                         ArrayRef<Attribute> params) {
140   if (failed(typeDef->verify(emitError, params)))
141     return {};
142   auto &ctx = typeDef->getContext();
143   return detail::TypeUniquer::getWithTypeID<DynamicType>(
144       &ctx, typeDef->getTypeID(), typeDef, params);
145 }
146 
getTypeDef()147 DynamicTypeDefinition *DynamicType::getTypeDef() { return getImpl()->typeDef; }
148 
getParams()149 ArrayRef<Attribute> DynamicType::getParams() { return getImpl()->params; }
150 
classof(Type type)151 bool DynamicType::classof(Type type) {
152   return type.hasTrait<TypeTrait::IsDynamicType>();
153 }
154 
parse(AsmParser & parser,DynamicTypeDefinition * typeDef,DynamicType & parsedType)155 ParseResult DynamicType::parse(AsmParser &parser,
156                                DynamicTypeDefinition *typeDef,
157                                DynamicType &parsedType) {
158   SmallVector<Attribute> params;
159   if (failed(typeDef->parser(parser, params)))
160     return failure();
161   parsedType = parser.getChecked<DynamicType>(typeDef, params);
162   if (!parsedType)
163     return failure();
164   return success();
165 }
166 
print(AsmPrinter & printer)167 void DynamicType::print(AsmPrinter &printer) {
168   printer << getTypeDef()->getName();
169   getTypeDef()->printer(printer, getParams());
170 }
171 
172 //===----------------------------------------------------------------------===//
173 // Dynamic attribute
174 //===----------------------------------------------------------------------===//
175 
176 std::unique_ptr<DynamicAttrDefinition>
get(StringRef name,ExtensibleDialect * dialect,VerifierFn && verifier)177 DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect,
178                            VerifierFn &&verifier) {
179   return DynamicAttrDefinition::get(name, dialect, std::move(verifier),
180                                     typeOrAttrParser, typeOrAttrPrinter);
181 }
182 
183 std::unique_ptr<DynamicAttrDefinition>
get(StringRef name,ExtensibleDialect * dialect,VerifierFn && verifier,ParserFn && parser,PrinterFn && printer)184 DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect,
185                            VerifierFn &&verifier, ParserFn &&parser,
186                            PrinterFn &&printer) {
187   return std::unique_ptr<DynamicAttrDefinition>(
188       new DynamicAttrDefinition(name, dialect, std::move(verifier),
189                                 std::move(parser), std::move(printer)));
190 }
191 
DynamicAttrDefinition(StringRef nameRef,ExtensibleDialect * dialect,VerifierFn && verifier,ParserFn && parser,PrinterFn && printer)192 DynamicAttrDefinition::DynamicAttrDefinition(StringRef nameRef,
193                                              ExtensibleDialect *dialect,
194                                              VerifierFn &&verifier,
195                                              ParserFn &&parser,
196                                              PrinterFn &&printer)
197     : name(nameRef), dialect(dialect), verifier(std::move(verifier)),
198       parser(std::move(parser)), printer(std::move(printer)),
199       ctx(dialect->getContext()) {}
200 
DynamicAttrDefinition(ExtensibleDialect * dialect,StringRef nameRef)201 DynamicAttrDefinition::DynamicAttrDefinition(ExtensibleDialect *dialect,
202                                              StringRef nameRef)
203     : name(nameRef), dialect(dialect), ctx(dialect->getContext()) {}
204 
registerInAttrUniquer()205 void DynamicAttrDefinition::registerInAttrUniquer() {
206   detail::AttributeUniquer::registerAttribute<DynamicAttr>(&getContext(),
207                                                            getTypeID());
208 }
209 
210 namespace mlir {
211 namespace detail {
212 /// Storage of DynamicAttr.
213 /// Contains a pointer to the attribute definition and attribute parameters.
214 struct DynamicAttrStorage : public AttributeStorage {
215   using KeyTy = std::pair<DynamicAttrDefinition *, ArrayRef<Attribute>>;
216 
DynamicAttrStoragemlir::detail::DynamicAttrStorage217   explicit DynamicAttrStorage(DynamicAttrDefinition *attrDef,
218                               ArrayRef<Attribute> params)
219       : attrDef(attrDef), params(params) {}
220 
operator ==mlir::detail::DynamicAttrStorage221   bool operator==(const KeyTy &key) const {
222     return attrDef == key.first && params == key.second;
223   }
224 
hashKeymlir::detail::DynamicAttrStorage225   static llvm::hash_code hashKey(const KeyTy &key) {
226     return llvm::hash_value(key);
227   }
228 
constructmlir::detail::DynamicAttrStorage229   static DynamicAttrStorage *construct(AttributeStorageAllocator &alloc,
230                                        const KeyTy &key) {
231     return new (alloc.allocate<DynamicAttrStorage>())
232         DynamicAttrStorage(key.first, alloc.copyInto(key.second));
233   }
234 
235   /// Definition of the type.
236   DynamicAttrDefinition *attrDef;
237 
238   /// The type parameters.
239   ArrayRef<Attribute> params;
240 };
241 } // namespace detail
242 } // namespace mlir
243 
get(DynamicAttrDefinition * attrDef,ArrayRef<Attribute> params)244 DynamicAttr DynamicAttr::get(DynamicAttrDefinition *attrDef,
245                              ArrayRef<Attribute> params) {
246   auto &ctx = attrDef->getContext();
247   return detail::AttributeUniquer::getWithTypeID<DynamicAttr>(
248       &ctx, attrDef->getTypeID(), attrDef, params);
249 }
250 
251 DynamicAttr
getChecked(function_ref<InFlightDiagnostic ()> emitError,DynamicAttrDefinition * attrDef,ArrayRef<Attribute> params)252 DynamicAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
253                         DynamicAttrDefinition *attrDef,
254                         ArrayRef<Attribute> params) {
255   if (failed(attrDef->verify(emitError, params)))
256     return {};
257   return get(attrDef, params);
258 }
259 
getAttrDef()260 DynamicAttrDefinition *DynamicAttr::getAttrDef() { return getImpl()->attrDef; }
261 
getParams()262 ArrayRef<Attribute> DynamicAttr::getParams() { return getImpl()->params; }
263 
classof(Attribute attr)264 bool DynamicAttr::classof(Attribute attr) {
265   return attr.hasTrait<AttributeTrait::IsDynamicAttr>();
266 }
267 
parse(AsmParser & parser,DynamicAttrDefinition * attrDef,DynamicAttr & parsedAttr)268 ParseResult DynamicAttr::parse(AsmParser &parser,
269                                DynamicAttrDefinition *attrDef,
270                                DynamicAttr &parsedAttr) {
271   SmallVector<Attribute> params;
272   if (failed(attrDef->parser(parser, params)))
273     return failure();
274   parsedAttr = parser.getChecked<DynamicAttr>(attrDef, params);
275   if (!parsedAttr)
276     return failure();
277   return success();
278 }
279 
print(AsmPrinter & printer)280 void DynamicAttr::print(AsmPrinter &printer) {
281   printer << getAttrDef()->getName();
282   getAttrDef()->printer(printer, getParams());
283 }
284 
285 //===----------------------------------------------------------------------===//
286 // Dynamic operation
287 //===----------------------------------------------------------------------===//
288 
DynamicOpDefinition(StringRef name,ExtensibleDialect * dialect,OperationName::VerifyInvariantsFn && verifyFn,OperationName::VerifyRegionInvariantsFn && verifyRegionFn,OperationName::ParseAssemblyFn && parseFn,OperationName::PrintAssemblyFn && printFn,OperationName::FoldHookFn && foldHookFn,GetCanonicalizationPatternsFn && getCanonicalizationPatternsFn,OperationName::PopulateDefaultAttrsFn && populateDefaultAttrsFn)289 DynamicOpDefinition::DynamicOpDefinition(
290     StringRef name, ExtensibleDialect *dialect,
291     OperationName::VerifyInvariantsFn &&verifyFn,
292     OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
293     OperationName::ParseAssemblyFn &&parseFn,
294     OperationName::PrintAssemblyFn &&printFn,
295     OperationName::FoldHookFn &&foldHookFn,
296     GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
297     OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn)
298     : Impl(StringAttr::get(dialect->getContext(),
299                            (dialect->getNamespace() + "." + name).str()),
300            dialect, dialect->allocateTypeID(),
301            /*interfaceMap=*/detail::InterfaceMap()),
302       verifyFn(std::move(verifyFn)), verifyRegionFn(std::move(verifyRegionFn)),
303       parseFn(std::move(parseFn)), printFn(std::move(printFn)),
304       foldHookFn(std::move(foldHookFn)),
305       getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)),
306       populateDefaultAttrsFn(std::move(populateDefaultAttrsFn)) {
307   typeID = dialect->allocateTypeID();
308 }
309 
get(StringRef name,ExtensibleDialect * dialect,OperationName::VerifyInvariantsFn && verifyFn,OperationName::VerifyRegionInvariantsFn && verifyRegionFn)310 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
311     StringRef name, ExtensibleDialect *dialect,
312     OperationName::VerifyInvariantsFn &&verifyFn,
313     OperationName::VerifyRegionInvariantsFn &&verifyRegionFn) {
314   auto parseFn = [](OpAsmParser &parser, OperationState &result) {
315     return parser.emitError(
316         parser.getCurrentLocation(),
317         "dynamic operation do not define any parser function");
318   };
319 
320   auto printFn = [](Operation *op, OpAsmPrinter &printer, StringRef) {
321     printer.printGenericOp(op);
322   };
323 
324   return DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
325                                   std::move(verifyRegionFn), std::move(parseFn),
326                                   std::move(printFn));
327 }
328 
get(StringRef name,ExtensibleDialect * dialect,OperationName::VerifyInvariantsFn && verifyFn,OperationName::VerifyRegionInvariantsFn && verifyRegionFn,OperationName::ParseAssemblyFn && parseFn,OperationName::PrintAssemblyFn && printFn)329 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
330     StringRef name, ExtensibleDialect *dialect,
331     OperationName::VerifyInvariantsFn &&verifyFn,
332     OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
333     OperationName::ParseAssemblyFn &&parseFn,
334     OperationName::PrintAssemblyFn &&printFn) {
335   auto foldHookFn = [](Operation *op, ArrayRef<Attribute> operands,
336                        SmallVectorImpl<OpFoldResult> &results) {
337     return failure();
338   };
339 
340   auto getCanonicalizationPatternsFn = [](RewritePatternSet &, MLIRContext *) {
341   };
342 
343   auto populateDefaultAttrsFn = [](const OperationName &, NamedAttrList &) {};
344 
345   return DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
346                                   std::move(verifyRegionFn), std::move(parseFn),
347                                   std::move(printFn), std::move(foldHookFn),
348                                   std::move(getCanonicalizationPatternsFn),
349                                   std::move(populateDefaultAttrsFn));
350 }
351 
get(StringRef name,ExtensibleDialect * dialect,OperationName::VerifyInvariantsFn && verifyFn,OperationName::VerifyInvariantsFn && verifyRegionFn,OperationName::ParseAssemblyFn && parseFn,OperationName::PrintAssemblyFn && printFn,OperationName::FoldHookFn && foldHookFn,GetCanonicalizationPatternsFn && getCanonicalizationPatternsFn,OperationName::PopulateDefaultAttrsFn && populateDefaultAttrsFn)352 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
353     StringRef name, ExtensibleDialect *dialect,
354     OperationName::VerifyInvariantsFn &&verifyFn,
355     OperationName::VerifyInvariantsFn &&verifyRegionFn,
356     OperationName::ParseAssemblyFn &&parseFn,
357     OperationName::PrintAssemblyFn &&printFn,
358     OperationName::FoldHookFn &&foldHookFn,
359     GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
360     OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn) {
361   return std::unique_ptr<DynamicOpDefinition>(new DynamicOpDefinition(
362       name, dialect, std::move(verifyFn), std::move(verifyRegionFn),
363       std::move(parseFn), std::move(printFn), std::move(foldHookFn),
364       std::move(getCanonicalizationPatternsFn),
365       std::move(populateDefaultAttrsFn)));
366 }
367 
368 //===----------------------------------------------------------------------===//
369 // Extensible dialect
370 //===----------------------------------------------------------------------===//
371 
372 namespace {
373 /// Interface that can only be implemented by extensible dialects.
374 /// The interface is used to check if a dialect is extensible or not.
375 class IsExtensibleDialect : public DialectInterface::Base<IsExtensibleDialect> {
376 public:
IsExtensibleDialect(Dialect * dialect)377   IsExtensibleDialect(Dialect *dialect) : Base(dialect) {}
378 
379   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IsExtensibleDialect)
380 };
381 } // namespace
382 
ExtensibleDialect(StringRef name,MLIRContext * ctx,TypeID typeID)383 ExtensibleDialect::ExtensibleDialect(StringRef name, MLIRContext *ctx,
384                                      TypeID typeID)
385     : Dialect(name, ctx, typeID) {
386   addInterfaces<IsExtensibleDialect>();
387 }
388 
registerDynamicType(std::unique_ptr<DynamicTypeDefinition> && type)389 void ExtensibleDialect::registerDynamicType(
390     std::unique_ptr<DynamicTypeDefinition> &&type) {
391   DynamicTypeDefinition *typePtr = type.get();
392   TypeID typeID = type->getTypeID();
393   StringRef name = type->getName();
394   ExtensibleDialect *dialect = type->getDialect();
395 
396   assert(dialect == this &&
397          "trying to register a dynamic type in the wrong dialect");
398 
399   // If a type with the same name is already defined, fail.
400   auto registered = dynTypes.try_emplace(typeID, std::move(type)).second;
401   (void)registered;
402   assert(registered && "type TypeID was not unique");
403 
404   registered = nameToDynTypes.insert({name, typePtr}).second;
405   (void)registered;
406   assert(registered &&
407          "Trying to create a new dynamic type with an existing name");
408 
409   // The StringAttr allocates the type name StringRef for the duration of the
410   // MLIR context.
411   MLIRContext *ctx = getContext();
412   auto nameAttr =
413       StringAttr::get(ctx, getNamespace() + "." + typePtr->getName());
414 
415   auto abstractType = AbstractType::get(
416       *dialect, DynamicAttr::getInterfaceMap(), DynamicType::getHasTraitFn(),
417       DynamicType::getWalkImmediateSubElementsFn(),
418       DynamicType::getReplaceImmediateSubElementsFn(), typeID, nameAttr);
419 
420   /// Add the type to the dialect and the type uniquer.
421   addType(typeID, std::move(abstractType));
422   typePtr->registerInTypeUniquer();
423 }
424 
registerDynamicAttr(std::unique_ptr<DynamicAttrDefinition> && attr)425 void ExtensibleDialect::registerDynamicAttr(
426     std::unique_ptr<DynamicAttrDefinition> &&attr) {
427   auto *attrPtr = attr.get();
428   auto typeID = attr->getTypeID();
429   auto name = attr->getName();
430   auto *dialect = attr->getDialect();
431 
432   assert(dialect == this &&
433          "trying to register a dynamic attribute in the wrong dialect");
434 
435   // If an attribute with the same name is already defined, fail.
436   auto registered = dynAttrs.try_emplace(typeID, std::move(attr)).second;
437   (void)registered;
438   assert(registered && "attribute TypeID was not unique");
439 
440   registered = nameToDynAttrs.insert({name, attrPtr}).second;
441   (void)registered;
442   assert(registered &&
443          "Trying to create a new dynamic attribute with an existing name");
444 
445   // The StringAttr allocates the attribute name StringRef for the duration of
446   // the MLIR context.
447   MLIRContext *ctx = getContext();
448   auto nameAttr =
449       StringAttr::get(ctx, getNamespace() + "." + attrPtr->getName());
450 
451   auto abstractAttr = AbstractAttribute::get(
452       *dialect, DynamicAttr::getInterfaceMap(), DynamicAttr::getHasTraitFn(),
453       DynamicAttr::getWalkImmediateSubElementsFn(),
454       DynamicAttr::getReplaceImmediateSubElementsFn(), typeID, nameAttr);
455 
456   /// Add the type to the dialect and the type uniquer.
457   addAttribute(typeID, std::move(abstractAttr));
458   attrPtr->registerInAttrUniquer();
459 }
460 
registerDynamicOp(std::unique_ptr<DynamicOpDefinition> && op)461 void ExtensibleDialect::registerDynamicOp(
462     std::unique_ptr<DynamicOpDefinition> &&op) {
463   assert(op->dialect == this &&
464          "trying to register a dynamic op in the wrong dialect");
465   RegisteredOperationName::insert(std::move(op), /*attrNames=*/{});
466 }
467 
classof(const Dialect * dialect)468 bool ExtensibleDialect::classof(const Dialect *dialect) {
469   return const_cast<Dialect *>(dialect)
470       ->getRegisteredInterface<IsExtensibleDialect>();
471 }
472 
parseOptionalDynamicType(StringRef typeName,AsmParser & parser,Type & resultType) const473 OptionalParseResult ExtensibleDialect::parseOptionalDynamicType(
474     StringRef typeName, AsmParser &parser, Type &resultType) const {
475   DynamicTypeDefinition *typeDef = lookupTypeDefinition(typeName);
476   if (!typeDef)
477     return std::nullopt;
478 
479   DynamicType dynType;
480   if (DynamicType::parse(parser, typeDef, dynType))
481     return failure();
482   resultType = dynType;
483   return success();
484 }
485 
printIfDynamicType(Type type,AsmPrinter & printer)486 LogicalResult ExtensibleDialect::printIfDynamicType(Type type,
487                                                     AsmPrinter &printer) {
488   if (auto dynType = llvm::dyn_cast<DynamicType>(type)) {
489     dynType.print(printer);
490     return success();
491   }
492   return failure();
493 }
494 
parseOptionalDynamicAttr(StringRef attrName,AsmParser & parser,Attribute & resultAttr) const495 OptionalParseResult ExtensibleDialect::parseOptionalDynamicAttr(
496     StringRef attrName, AsmParser &parser, Attribute &resultAttr) const {
497   DynamicAttrDefinition *attrDef = lookupAttrDefinition(attrName);
498   if (!attrDef)
499     return std::nullopt;
500 
501   DynamicAttr dynAttr;
502   if (DynamicAttr::parse(parser, attrDef, dynAttr))
503     return failure();
504   resultAttr = dynAttr;
505   return success();
506 }
507 
printIfDynamicAttr(Attribute attribute,AsmPrinter & printer)508 LogicalResult ExtensibleDialect::printIfDynamicAttr(Attribute attribute,
509                                                     AsmPrinter &printer) {
510   if (auto dynAttr = llvm::dyn_cast<DynamicAttr>(attribute)) {
511     dynAttr.print(printer);
512     return success();
513   }
514   return failure();
515 }
516 
517 //===----------------------------------------------------------------------===//
518 // Dynamic dialect
519 //===----------------------------------------------------------------------===//
520 
521 namespace {
522 /// Interface that can only be implemented by extensible dialects.
523 /// The interface is used to check if a dialect is extensible or not.
524 class IsDynamicDialect : public DialectInterface::Base<IsDynamicDialect> {
525 public:
IsDynamicDialect(Dialect * dialect)526   IsDynamicDialect(Dialect *dialect) : Base(dialect) {}
527 
528   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IsDynamicDialect)
529 };
530 } // namespace
531 
DynamicDialect(StringRef name,MLIRContext * ctx)532 DynamicDialect::DynamicDialect(StringRef name, MLIRContext *ctx)
533     : SelfOwningTypeID(),
534       ExtensibleDialect(name, ctx, SelfOwningTypeID::getTypeID()) {
535   addInterfaces<IsDynamicDialect>();
536 }
537 
classof(const Dialect * dialect)538 bool DynamicDialect::classof(const Dialect *dialect) {
539   return const_cast<Dialect *>(dialect)
540       ->getRegisteredInterface<IsDynamicDialect>();
541 }
542 
parseType(DialectAsmParser & parser) const543 Type DynamicDialect::parseType(DialectAsmParser &parser) const {
544   auto loc = parser.getCurrentLocation();
545   StringRef typeTag;
546   if (failed(parser.parseKeyword(&typeTag)))
547     return Type();
548 
549   {
550     Type dynType;
551     auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType);
552     if (parseResult.has_value()) {
553       if (succeeded(parseResult.value()))
554         return dynType;
555       return Type();
556     }
557   }
558 
559   parser.emitError(loc, "expected dynamic type");
560   return Type();
561 }
562 
printType(Type type,DialectAsmPrinter & printer) const563 void DynamicDialect::printType(Type type, DialectAsmPrinter &printer) const {
564   auto wasDynamic = printIfDynamicType(type, printer);
565   (void)wasDynamic;
566   assert(succeeded(wasDynamic) &&
567          "non-dynamic type defined in dynamic dialect");
568 }
569 
parseAttribute(DialectAsmParser & parser,Type type) const570 Attribute DynamicDialect::parseAttribute(DialectAsmParser &parser,
571                                          Type type) const {
572   auto loc = parser.getCurrentLocation();
573   StringRef typeTag;
574   if (failed(parser.parseKeyword(&typeTag)))
575     return Attribute();
576 
577   {
578     Attribute dynAttr;
579     auto parseResult = parseOptionalDynamicAttr(typeTag, parser, dynAttr);
580     if (parseResult.has_value()) {
581       if (succeeded(parseResult.value()))
582         return dynAttr;
583       return Attribute();
584     }
585   }
586 
587   parser.emitError(loc, "expected dynamic attribute");
588   return Attribute();
589 }
printAttribute(Attribute attr,DialectAsmPrinter & printer) const590 void DynamicDialect::printAttribute(Attribute attr,
591                                     DialectAsmPrinter &printer) const {
592   auto wasDynamic = printIfDynamicAttr(attr, printer);
593   (void)wasDynamic;
594   assert(succeeded(wasDynamic) &&
595          "non-dynamic attribute defined in dynamic dialect");
596 }
597