1 //===- ExtensibleDialect.cpp - Extensible dialect ---------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/ExtensibleDialect.h"
10 #include "mlir/IR/AttributeSupport.h"
11 #include "mlir/IR/DialectImplementation.h"
12 #include "mlir/IR/OperationSupport.h"
13 #include "mlir/IR/StorageUniquerSupport.h"
14
15 using namespace mlir;
16
17 //===----------------------------------------------------------------------===//
18 // Dynamic types and attributes shared functions
19 //===----------------------------------------------------------------------===//
20
21 /// Default parser for dynamic attribute or type parameters.
22 /// Parse in the format '(<>)?' or '<attr (,attr)*>'.
23 static LogicalResult
typeOrAttrParser(AsmParser & parser,SmallVectorImpl<Attribute> & parsedParams)24 typeOrAttrParser(AsmParser &parser, SmallVectorImpl<Attribute> &parsedParams) {
25 // No parameters
26 if (parser.parseOptionalLess() || !parser.parseOptionalGreater())
27 return success();
28
29 Attribute attr;
30 if (parser.parseAttribute(attr))
31 return failure();
32 parsedParams.push_back(attr);
33
34 while (parser.parseOptionalGreater()) {
35 Attribute attr;
36 if (parser.parseComma() || parser.parseAttribute(attr))
37 return failure();
38 parsedParams.push_back(attr);
39 }
40
41 return success();
42 }
43
44 /// Default printer for dynamic attribute or type parameters.
45 /// Print in the format '(<>)?' or '<attr (,attr)*>'.
typeOrAttrPrinter(AsmPrinter & printer,ArrayRef<Attribute> params)46 static void typeOrAttrPrinter(AsmPrinter &printer, ArrayRef<Attribute> params) {
47 if (params.empty())
48 return;
49
50 printer << "<";
51 interleaveComma(params, printer.getStream());
52 printer << ">";
53 }
54
55 //===----------------------------------------------------------------------===//
56 // Dynamic type
57 //===----------------------------------------------------------------------===//
58
59 std::unique_ptr<DynamicTypeDefinition>
get(StringRef name,ExtensibleDialect * dialect,VerifierFn && verifier)60 DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect,
61 VerifierFn &&verifier) {
62 return DynamicTypeDefinition::get(name, dialect, std::move(verifier),
63 typeOrAttrParser, typeOrAttrPrinter);
64 }
65
66 std::unique_ptr<DynamicTypeDefinition>
get(StringRef name,ExtensibleDialect * dialect,VerifierFn && verifier,ParserFn && parser,PrinterFn && printer)67 DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect,
68 VerifierFn &&verifier, ParserFn &&parser,
69 PrinterFn &&printer) {
70 return std::unique_ptr<DynamicTypeDefinition>(
71 new DynamicTypeDefinition(name, dialect, std::move(verifier),
72 std::move(parser), std::move(printer)));
73 }
74
DynamicTypeDefinition(StringRef nameRef,ExtensibleDialect * dialect,VerifierFn && verifier,ParserFn && parser,PrinterFn && printer)75 DynamicTypeDefinition::DynamicTypeDefinition(StringRef nameRef,
76 ExtensibleDialect *dialect,
77 VerifierFn &&verifier,
78 ParserFn &&parser,
79 PrinterFn &&printer)
80 : name(nameRef), dialect(dialect), verifier(std::move(verifier)),
81 parser(std::move(parser)), printer(std::move(printer)),
82 ctx(dialect->getContext()) {}
83
DynamicTypeDefinition(ExtensibleDialect * dialect,StringRef nameRef)84 DynamicTypeDefinition::DynamicTypeDefinition(ExtensibleDialect *dialect,
85 StringRef nameRef)
86 : name(nameRef), dialect(dialect), ctx(dialect->getContext()) {}
87
registerInTypeUniquer()88 void DynamicTypeDefinition::registerInTypeUniquer() {
89 detail::TypeUniquer::registerType<DynamicType>(&getContext(), getTypeID());
90 }
91
92 namespace mlir {
93 namespace detail {
94 /// Storage of DynamicType.
95 /// Contains a pointer to the type definition and type parameters.
96 struct DynamicTypeStorage : public TypeStorage {
97
98 using KeyTy = std::pair<DynamicTypeDefinition *, ArrayRef<Attribute>>;
99
DynamicTypeStoragemlir::detail::DynamicTypeStorage100 explicit DynamicTypeStorage(DynamicTypeDefinition *typeDef,
101 ArrayRef<Attribute> params)
102 : typeDef(typeDef), params(params) {}
103
operator ==mlir::detail::DynamicTypeStorage104 bool operator==(const KeyTy &key) const {
105 return typeDef == key.first && params == key.second;
106 }
107
hashKeymlir::detail::DynamicTypeStorage108 static llvm::hash_code hashKey(const KeyTy &key) {
109 return llvm::hash_value(key);
110 }
111
constructmlir::detail::DynamicTypeStorage112 static DynamicTypeStorage *construct(TypeStorageAllocator &alloc,
113 const KeyTy &key) {
114 return new (alloc.allocate<DynamicTypeStorage>())
115 DynamicTypeStorage(key.first, alloc.copyInto(key.second));
116 }
117
118 /// Definition of the type.
119 DynamicTypeDefinition *typeDef;
120
121 /// The type parameters.
122 ArrayRef<Attribute> params;
123 };
124 } // namespace detail
125 } // namespace mlir
126
get(DynamicTypeDefinition * typeDef,ArrayRef<Attribute> params)127 DynamicType DynamicType::get(DynamicTypeDefinition *typeDef,
128 ArrayRef<Attribute> params) {
129 auto &ctx = typeDef->getContext();
130 auto emitError = detail::getDefaultDiagnosticEmitFn(&ctx);
131 assert(succeeded(typeDef->verify(emitError, params)));
132 return detail::TypeUniquer::getWithTypeID<DynamicType>(
133 &ctx, typeDef->getTypeID(), typeDef, params);
134 }
135
136 DynamicType
getChecked(function_ref<InFlightDiagnostic ()> emitError,DynamicTypeDefinition * typeDef,ArrayRef<Attribute> params)137 DynamicType::getChecked(function_ref<InFlightDiagnostic()> emitError,
138 DynamicTypeDefinition *typeDef,
139 ArrayRef<Attribute> params) {
140 if (failed(typeDef->verify(emitError, params)))
141 return {};
142 auto &ctx = typeDef->getContext();
143 return detail::TypeUniquer::getWithTypeID<DynamicType>(
144 &ctx, typeDef->getTypeID(), typeDef, params);
145 }
146
getTypeDef()147 DynamicTypeDefinition *DynamicType::getTypeDef() { return getImpl()->typeDef; }
148
getParams()149 ArrayRef<Attribute> DynamicType::getParams() { return getImpl()->params; }
150
classof(Type type)151 bool DynamicType::classof(Type type) {
152 return type.hasTrait<TypeTrait::IsDynamicType>();
153 }
154
parse(AsmParser & parser,DynamicTypeDefinition * typeDef,DynamicType & parsedType)155 ParseResult DynamicType::parse(AsmParser &parser,
156 DynamicTypeDefinition *typeDef,
157 DynamicType &parsedType) {
158 SmallVector<Attribute> params;
159 if (failed(typeDef->parser(parser, params)))
160 return failure();
161 parsedType = parser.getChecked<DynamicType>(typeDef, params);
162 if (!parsedType)
163 return failure();
164 return success();
165 }
166
print(AsmPrinter & printer)167 void DynamicType::print(AsmPrinter &printer) {
168 printer << getTypeDef()->getName();
169 getTypeDef()->printer(printer, getParams());
170 }
171
172 //===----------------------------------------------------------------------===//
173 // Dynamic attribute
174 //===----------------------------------------------------------------------===//
175
176 std::unique_ptr<DynamicAttrDefinition>
get(StringRef name,ExtensibleDialect * dialect,VerifierFn && verifier)177 DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect,
178 VerifierFn &&verifier) {
179 return DynamicAttrDefinition::get(name, dialect, std::move(verifier),
180 typeOrAttrParser, typeOrAttrPrinter);
181 }
182
183 std::unique_ptr<DynamicAttrDefinition>
get(StringRef name,ExtensibleDialect * dialect,VerifierFn && verifier,ParserFn && parser,PrinterFn && printer)184 DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect,
185 VerifierFn &&verifier, ParserFn &&parser,
186 PrinterFn &&printer) {
187 return std::unique_ptr<DynamicAttrDefinition>(
188 new DynamicAttrDefinition(name, dialect, std::move(verifier),
189 std::move(parser), std::move(printer)));
190 }
191
DynamicAttrDefinition(StringRef nameRef,ExtensibleDialect * dialect,VerifierFn && verifier,ParserFn && parser,PrinterFn && printer)192 DynamicAttrDefinition::DynamicAttrDefinition(StringRef nameRef,
193 ExtensibleDialect *dialect,
194 VerifierFn &&verifier,
195 ParserFn &&parser,
196 PrinterFn &&printer)
197 : name(nameRef), dialect(dialect), verifier(std::move(verifier)),
198 parser(std::move(parser)), printer(std::move(printer)),
199 ctx(dialect->getContext()) {}
200
DynamicAttrDefinition(ExtensibleDialect * dialect,StringRef nameRef)201 DynamicAttrDefinition::DynamicAttrDefinition(ExtensibleDialect *dialect,
202 StringRef nameRef)
203 : name(nameRef), dialect(dialect), ctx(dialect->getContext()) {}
204
registerInAttrUniquer()205 void DynamicAttrDefinition::registerInAttrUniquer() {
206 detail::AttributeUniquer::registerAttribute<DynamicAttr>(&getContext(),
207 getTypeID());
208 }
209
210 namespace mlir {
211 namespace detail {
212 /// Storage of DynamicAttr.
213 /// Contains a pointer to the attribute definition and attribute parameters.
214 struct DynamicAttrStorage : public AttributeStorage {
215 using KeyTy = std::pair<DynamicAttrDefinition *, ArrayRef<Attribute>>;
216
DynamicAttrStoragemlir::detail::DynamicAttrStorage217 explicit DynamicAttrStorage(DynamicAttrDefinition *attrDef,
218 ArrayRef<Attribute> params)
219 : attrDef(attrDef), params(params) {}
220
operator ==mlir::detail::DynamicAttrStorage221 bool operator==(const KeyTy &key) const {
222 return attrDef == key.first && params == key.second;
223 }
224
hashKeymlir::detail::DynamicAttrStorage225 static llvm::hash_code hashKey(const KeyTy &key) {
226 return llvm::hash_value(key);
227 }
228
constructmlir::detail::DynamicAttrStorage229 static DynamicAttrStorage *construct(AttributeStorageAllocator &alloc,
230 const KeyTy &key) {
231 return new (alloc.allocate<DynamicAttrStorage>())
232 DynamicAttrStorage(key.first, alloc.copyInto(key.second));
233 }
234
235 /// Definition of the type.
236 DynamicAttrDefinition *attrDef;
237
238 /// The type parameters.
239 ArrayRef<Attribute> params;
240 };
241 } // namespace detail
242 } // namespace mlir
243
get(DynamicAttrDefinition * attrDef,ArrayRef<Attribute> params)244 DynamicAttr DynamicAttr::get(DynamicAttrDefinition *attrDef,
245 ArrayRef<Attribute> params) {
246 auto &ctx = attrDef->getContext();
247 return detail::AttributeUniquer::getWithTypeID<DynamicAttr>(
248 &ctx, attrDef->getTypeID(), attrDef, params);
249 }
250
251 DynamicAttr
getChecked(function_ref<InFlightDiagnostic ()> emitError,DynamicAttrDefinition * attrDef,ArrayRef<Attribute> params)252 DynamicAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
253 DynamicAttrDefinition *attrDef,
254 ArrayRef<Attribute> params) {
255 if (failed(attrDef->verify(emitError, params)))
256 return {};
257 return get(attrDef, params);
258 }
259
getAttrDef()260 DynamicAttrDefinition *DynamicAttr::getAttrDef() { return getImpl()->attrDef; }
261
getParams()262 ArrayRef<Attribute> DynamicAttr::getParams() { return getImpl()->params; }
263
classof(Attribute attr)264 bool DynamicAttr::classof(Attribute attr) {
265 return attr.hasTrait<AttributeTrait::IsDynamicAttr>();
266 }
267
parse(AsmParser & parser,DynamicAttrDefinition * attrDef,DynamicAttr & parsedAttr)268 ParseResult DynamicAttr::parse(AsmParser &parser,
269 DynamicAttrDefinition *attrDef,
270 DynamicAttr &parsedAttr) {
271 SmallVector<Attribute> params;
272 if (failed(attrDef->parser(parser, params)))
273 return failure();
274 parsedAttr = parser.getChecked<DynamicAttr>(attrDef, params);
275 if (!parsedAttr)
276 return failure();
277 return success();
278 }
279
print(AsmPrinter & printer)280 void DynamicAttr::print(AsmPrinter &printer) {
281 printer << getAttrDef()->getName();
282 getAttrDef()->printer(printer, getParams());
283 }
284
285 //===----------------------------------------------------------------------===//
286 // Dynamic operation
287 //===----------------------------------------------------------------------===//
288
DynamicOpDefinition(StringRef name,ExtensibleDialect * dialect,OperationName::VerifyInvariantsFn && verifyFn,OperationName::VerifyRegionInvariantsFn && verifyRegionFn,OperationName::ParseAssemblyFn && parseFn,OperationName::PrintAssemblyFn && printFn,OperationName::FoldHookFn && foldHookFn,GetCanonicalizationPatternsFn && getCanonicalizationPatternsFn,OperationName::PopulateDefaultAttrsFn && populateDefaultAttrsFn)289 DynamicOpDefinition::DynamicOpDefinition(
290 StringRef name, ExtensibleDialect *dialect,
291 OperationName::VerifyInvariantsFn &&verifyFn,
292 OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
293 OperationName::ParseAssemblyFn &&parseFn,
294 OperationName::PrintAssemblyFn &&printFn,
295 OperationName::FoldHookFn &&foldHookFn,
296 GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
297 OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn)
298 : Impl(StringAttr::get(dialect->getContext(),
299 (dialect->getNamespace() + "." + name).str()),
300 dialect, dialect->allocateTypeID(),
301 /*interfaceMap=*/detail::InterfaceMap()),
302 verifyFn(std::move(verifyFn)), verifyRegionFn(std::move(verifyRegionFn)),
303 parseFn(std::move(parseFn)), printFn(std::move(printFn)),
304 foldHookFn(std::move(foldHookFn)),
305 getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)),
306 populateDefaultAttrsFn(std::move(populateDefaultAttrsFn)) {
307 typeID = dialect->allocateTypeID();
308 }
309
get(StringRef name,ExtensibleDialect * dialect,OperationName::VerifyInvariantsFn && verifyFn,OperationName::VerifyRegionInvariantsFn && verifyRegionFn)310 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
311 StringRef name, ExtensibleDialect *dialect,
312 OperationName::VerifyInvariantsFn &&verifyFn,
313 OperationName::VerifyRegionInvariantsFn &&verifyRegionFn) {
314 auto parseFn = [](OpAsmParser &parser, OperationState &result) {
315 return parser.emitError(
316 parser.getCurrentLocation(),
317 "dynamic operation do not define any parser function");
318 };
319
320 auto printFn = [](Operation *op, OpAsmPrinter &printer, StringRef) {
321 printer.printGenericOp(op);
322 };
323
324 return DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
325 std::move(verifyRegionFn), std::move(parseFn),
326 std::move(printFn));
327 }
328
get(StringRef name,ExtensibleDialect * dialect,OperationName::VerifyInvariantsFn && verifyFn,OperationName::VerifyRegionInvariantsFn && verifyRegionFn,OperationName::ParseAssemblyFn && parseFn,OperationName::PrintAssemblyFn && printFn)329 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
330 StringRef name, ExtensibleDialect *dialect,
331 OperationName::VerifyInvariantsFn &&verifyFn,
332 OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
333 OperationName::ParseAssemblyFn &&parseFn,
334 OperationName::PrintAssemblyFn &&printFn) {
335 auto foldHookFn = [](Operation *op, ArrayRef<Attribute> operands,
336 SmallVectorImpl<OpFoldResult> &results) {
337 return failure();
338 };
339
340 auto getCanonicalizationPatternsFn = [](RewritePatternSet &, MLIRContext *) {
341 };
342
343 auto populateDefaultAttrsFn = [](const OperationName &, NamedAttrList &) {};
344
345 return DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
346 std::move(verifyRegionFn), std::move(parseFn),
347 std::move(printFn), std::move(foldHookFn),
348 std::move(getCanonicalizationPatternsFn),
349 std::move(populateDefaultAttrsFn));
350 }
351
get(StringRef name,ExtensibleDialect * dialect,OperationName::VerifyInvariantsFn && verifyFn,OperationName::VerifyInvariantsFn && verifyRegionFn,OperationName::ParseAssemblyFn && parseFn,OperationName::PrintAssemblyFn && printFn,OperationName::FoldHookFn && foldHookFn,GetCanonicalizationPatternsFn && getCanonicalizationPatternsFn,OperationName::PopulateDefaultAttrsFn && populateDefaultAttrsFn)352 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
353 StringRef name, ExtensibleDialect *dialect,
354 OperationName::VerifyInvariantsFn &&verifyFn,
355 OperationName::VerifyInvariantsFn &&verifyRegionFn,
356 OperationName::ParseAssemblyFn &&parseFn,
357 OperationName::PrintAssemblyFn &&printFn,
358 OperationName::FoldHookFn &&foldHookFn,
359 GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn,
360 OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn) {
361 return std::unique_ptr<DynamicOpDefinition>(new DynamicOpDefinition(
362 name, dialect, std::move(verifyFn), std::move(verifyRegionFn),
363 std::move(parseFn), std::move(printFn), std::move(foldHookFn),
364 std::move(getCanonicalizationPatternsFn),
365 std::move(populateDefaultAttrsFn)));
366 }
367
368 //===----------------------------------------------------------------------===//
369 // Extensible dialect
370 //===----------------------------------------------------------------------===//
371
372 namespace {
373 /// Interface that can only be implemented by extensible dialects.
374 /// The interface is used to check if a dialect is extensible or not.
375 class IsExtensibleDialect : public DialectInterface::Base<IsExtensibleDialect> {
376 public:
IsExtensibleDialect(Dialect * dialect)377 IsExtensibleDialect(Dialect *dialect) : Base(dialect) {}
378
379 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IsExtensibleDialect)
380 };
381 } // namespace
382
ExtensibleDialect(StringRef name,MLIRContext * ctx,TypeID typeID)383 ExtensibleDialect::ExtensibleDialect(StringRef name, MLIRContext *ctx,
384 TypeID typeID)
385 : Dialect(name, ctx, typeID) {
386 addInterfaces<IsExtensibleDialect>();
387 }
388
registerDynamicType(std::unique_ptr<DynamicTypeDefinition> && type)389 void ExtensibleDialect::registerDynamicType(
390 std::unique_ptr<DynamicTypeDefinition> &&type) {
391 DynamicTypeDefinition *typePtr = type.get();
392 TypeID typeID = type->getTypeID();
393 StringRef name = type->getName();
394 ExtensibleDialect *dialect = type->getDialect();
395
396 assert(dialect == this &&
397 "trying to register a dynamic type in the wrong dialect");
398
399 // If a type with the same name is already defined, fail.
400 auto registered = dynTypes.try_emplace(typeID, std::move(type)).second;
401 (void)registered;
402 assert(registered && "type TypeID was not unique");
403
404 registered = nameToDynTypes.insert({name, typePtr}).second;
405 (void)registered;
406 assert(registered &&
407 "Trying to create a new dynamic type with an existing name");
408
409 // The StringAttr allocates the type name StringRef for the duration of the
410 // MLIR context.
411 MLIRContext *ctx = getContext();
412 auto nameAttr =
413 StringAttr::get(ctx, getNamespace() + "." + typePtr->getName());
414
415 auto abstractType = AbstractType::get(
416 *dialect, DynamicAttr::getInterfaceMap(), DynamicType::getHasTraitFn(),
417 DynamicType::getWalkImmediateSubElementsFn(),
418 DynamicType::getReplaceImmediateSubElementsFn(), typeID, nameAttr);
419
420 /// Add the type to the dialect and the type uniquer.
421 addType(typeID, std::move(abstractType));
422 typePtr->registerInTypeUniquer();
423 }
424
registerDynamicAttr(std::unique_ptr<DynamicAttrDefinition> && attr)425 void ExtensibleDialect::registerDynamicAttr(
426 std::unique_ptr<DynamicAttrDefinition> &&attr) {
427 auto *attrPtr = attr.get();
428 auto typeID = attr->getTypeID();
429 auto name = attr->getName();
430 auto *dialect = attr->getDialect();
431
432 assert(dialect == this &&
433 "trying to register a dynamic attribute in the wrong dialect");
434
435 // If an attribute with the same name is already defined, fail.
436 auto registered = dynAttrs.try_emplace(typeID, std::move(attr)).second;
437 (void)registered;
438 assert(registered && "attribute TypeID was not unique");
439
440 registered = nameToDynAttrs.insert({name, attrPtr}).second;
441 (void)registered;
442 assert(registered &&
443 "Trying to create a new dynamic attribute with an existing name");
444
445 // The StringAttr allocates the attribute name StringRef for the duration of
446 // the MLIR context.
447 MLIRContext *ctx = getContext();
448 auto nameAttr =
449 StringAttr::get(ctx, getNamespace() + "." + attrPtr->getName());
450
451 auto abstractAttr = AbstractAttribute::get(
452 *dialect, DynamicAttr::getInterfaceMap(), DynamicAttr::getHasTraitFn(),
453 DynamicAttr::getWalkImmediateSubElementsFn(),
454 DynamicAttr::getReplaceImmediateSubElementsFn(), typeID, nameAttr);
455
456 /// Add the type to the dialect and the type uniquer.
457 addAttribute(typeID, std::move(abstractAttr));
458 attrPtr->registerInAttrUniquer();
459 }
460
registerDynamicOp(std::unique_ptr<DynamicOpDefinition> && op)461 void ExtensibleDialect::registerDynamicOp(
462 std::unique_ptr<DynamicOpDefinition> &&op) {
463 assert(op->dialect == this &&
464 "trying to register a dynamic op in the wrong dialect");
465 RegisteredOperationName::insert(std::move(op), /*attrNames=*/{});
466 }
467
classof(const Dialect * dialect)468 bool ExtensibleDialect::classof(const Dialect *dialect) {
469 return const_cast<Dialect *>(dialect)
470 ->getRegisteredInterface<IsExtensibleDialect>();
471 }
472
parseOptionalDynamicType(StringRef typeName,AsmParser & parser,Type & resultType) const473 OptionalParseResult ExtensibleDialect::parseOptionalDynamicType(
474 StringRef typeName, AsmParser &parser, Type &resultType) const {
475 DynamicTypeDefinition *typeDef = lookupTypeDefinition(typeName);
476 if (!typeDef)
477 return std::nullopt;
478
479 DynamicType dynType;
480 if (DynamicType::parse(parser, typeDef, dynType))
481 return failure();
482 resultType = dynType;
483 return success();
484 }
485
printIfDynamicType(Type type,AsmPrinter & printer)486 LogicalResult ExtensibleDialect::printIfDynamicType(Type type,
487 AsmPrinter &printer) {
488 if (auto dynType = llvm::dyn_cast<DynamicType>(type)) {
489 dynType.print(printer);
490 return success();
491 }
492 return failure();
493 }
494
parseOptionalDynamicAttr(StringRef attrName,AsmParser & parser,Attribute & resultAttr) const495 OptionalParseResult ExtensibleDialect::parseOptionalDynamicAttr(
496 StringRef attrName, AsmParser &parser, Attribute &resultAttr) const {
497 DynamicAttrDefinition *attrDef = lookupAttrDefinition(attrName);
498 if (!attrDef)
499 return std::nullopt;
500
501 DynamicAttr dynAttr;
502 if (DynamicAttr::parse(parser, attrDef, dynAttr))
503 return failure();
504 resultAttr = dynAttr;
505 return success();
506 }
507
printIfDynamicAttr(Attribute attribute,AsmPrinter & printer)508 LogicalResult ExtensibleDialect::printIfDynamicAttr(Attribute attribute,
509 AsmPrinter &printer) {
510 if (auto dynAttr = llvm::dyn_cast<DynamicAttr>(attribute)) {
511 dynAttr.print(printer);
512 return success();
513 }
514 return failure();
515 }
516
517 //===----------------------------------------------------------------------===//
518 // Dynamic dialect
519 //===----------------------------------------------------------------------===//
520
521 namespace {
522 /// Interface that can only be implemented by extensible dialects.
523 /// The interface is used to check if a dialect is extensible or not.
524 class IsDynamicDialect : public DialectInterface::Base<IsDynamicDialect> {
525 public:
IsDynamicDialect(Dialect * dialect)526 IsDynamicDialect(Dialect *dialect) : Base(dialect) {}
527
528 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IsDynamicDialect)
529 };
530 } // namespace
531
DynamicDialect(StringRef name,MLIRContext * ctx)532 DynamicDialect::DynamicDialect(StringRef name, MLIRContext *ctx)
533 : SelfOwningTypeID(),
534 ExtensibleDialect(name, ctx, SelfOwningTypeID::getTypeID()) {
535 addInterfaces<IsDynamicDialect>();
536 }
537
classof(const Dialect * dialect)538 bool DynamicDialect::classof(const Dialect *dialect) {
539 return const_cast<Dialect *>(dialect)
540 ->getRegisteredInterface<IsDynamicDialect>();
541 }
542
parseType(DialectAsmParser & parser) const543 Type DynamicDialect::parseType(DialectAsmParser &parser) const {
544 auto loc = parser.getCurrentLocation();
545 StringRef typeTag;
546 if (failed(parser.parseKeyword(&typeTag)))
547 return Type();
548
549 {
550 Type dynType;
551 auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType);
552 if (parseResult.has_value()) {
553 if (succeeded(parseResult.value()))
554 return dynType;
555 return Type();
556 }
557 }
558
559 parser.emitError(loc, "expected dynamic type");
560 return Type();
561 }
562
printType(Type type,DialectAsmPrinter & printer) const563 void DynamicDialect::printType(Type type, DialectAsmPrinter &printer) const {
564 auto wasDynamic = printIfDynamicType(type, printer);
565 (void)wasDynamic;
566 assert(succeeded(wasDynamic) &&
567 "non-dynamic type defined in dynamic dialect");
568 }
569
parseAttribute(DialectAsmParser & parser,Type type) const570 Attribute DynamicDialect::parseAttribute(DialectAsmParser &parser,
571 Type type) const {
572 auto loc = parser.getCurrentLocation();
573 StringRef typeTag;
574 if (failed(parser.parseKeyword(&typeTag)))
575 return Attribute();
576
577 {
578 Attribute dynAttr;
579 auto parseResult = parseOptionalDynamicAttr(typeTag, parser, dynAttr);
580 if (parseResult.has_value()) {
581 if (succeeded(parseResult.value()))
582 return dynAttr;
583 return Attribute();
584 }
585 }
586
587 parser.emitError(loc, "expected dynamic attribute");
588 return Attribute();
589 }
printAttribute(Attribute attr,DialectAsmPrinter & printer) const590 void DynamicDialect::printAttribute(Attribute attr,
591 DialectAsmPrinter &printer) const {
592 auto wasDynamic = printIfDynamicAttr(attr, printer);
593 (void)wasDynamic;
594 assert(succeeded(wasDynamic) &&
595 "non-dynamic attribute defined in dynamic dialect");
596 }
597