1 //===- OpFormatGen.cpp - MLIR operation asm format generator --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "OpFormatGen.h" 10 #include "FormatGen.h" 11 #include "OpClass.h" 12 #include "mlir/Support/LLVM.h" 13 #include "mlir/TableGen/Class.h" 14 #include "mlir/TableGen/Format.h" 15 #include "mlir/TableGen/Operator.h" 16 #include "mlir/TableGen/Trait.h" 17 #include "llvm/ADT/MapVector.h" 18 #include "llvm/ADT/Sequence.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/SmallBitVector.h" 21 #include "llvm/ADT/StringExtras.h" 22 #include "llvm/ADT/TypeSwitch.h" 23 #include "llvm/Support/Signals.h" 24 #include "llvm/Support/SourceMgr.h" 25 #include "llvm/TableGen/Record.h" 26 27 #define DEBUG_TYPE "mlir-tblgen-opformatgen" 28 29 using namespace mlir; 30 using namespace mlir::tblgen; 31 using llvm::formatv; 32 using llvm::Record; 33 using llvm::StringMap; 34 35 //===----------------------------------------------------------------------===// 36 // VariableElement 37 38 namespace { 39 /// This class represents an instance of an op variable element. A variable 40 /// refers to something registered on the operation itself, e.g. an operand, 41 /// result, attribute, region, or successor. 42 template <typename VarT, VariableElement::Kind VariableKind> 43 class OpVariableElement : public VariableElementBase<VariableKind> { 44 public: 45 using Base = OpVariableElement<VarT, VariableKind>; 46 47 /// Create an op variable element with the variable value. 48 OpVariableElement(const VarT *var) : var(var) {} 49 50 /// Get the variable. 51 const VarT *getVar() const { return var; } 52 53 protected: 54 /// The op variable, e.g. a type or attribute constraint. 
55 const VarT *var; 56 }; 57 58 /// This class represents a variable that refers to an attribute argument. 59 struct AttributeVariable 60 : public OpVariableElement<NamedAttribute, VariableElement::Attribute> { 61 using Base::Base; 62 63 /// Return the constant builder call for the type of this attribute, or 64 /// std::nullopt if it doesn't have one. 65 std::optional<StringRef> getTypeBuilder() const { 66 std::optional<Type> attrType = var->attr.getValueType(); 67 return attrType ? attrType->getBuilderCall() : std::nullopt; 68 } 69 70 /// Indicate if this attribute is printed "qualified" (that is it is 71 /// prefixed with the `#dialect.mnemonic`). 72 bool shouldBeQualified() { return shouldBeQualifiedFlag; } 73 void setShouldBeQualified(bool qualified = true) { 74 shouldBeQualifiedFlag = qualified; 75 } 76 77 private: 78 bool shouldBeQualifiedFlag = false; 79 }; 80 81 /// This class represents a variable that refers to an operand argument. 82 using OperandVariable = 83 OpVariableElement<NamedTypeConstraint, VariableElement::Operand>; 84 85 /// This class represents a variable that refers to a result. 86 using ResultVariable = 87 OpVariableElement<NamedTypeConstraint, VariableElement::Result>; 88 89 /// This class represents a variable that refers to a region. 90 using RegionVariable = OpVariableElement<NamedRegion, VariableElement::Region>; 91 92 /// This class represents a variable that refers to a successor. 93 using SuccessorVariable = 94 OpVariableElement<NamedSuccessor, VariableElement::Successor>; 95 96 /// This class represents a variable that refers to a property argument. 97 using PropertyVariable = 98 OpVariableElement<NamedProperty, VariableElement::Property>; 99 100 /// LLVM RTTI helper for attribute-like variables, that is, attributes or 101 /// properties. This allows for common handling of attributes and properties in 102 /// parts of the code that are oblivious to whether something is stored as an 103 /// attribute or a property. 
104 struct AttributeLikeVariable : public VariableElement { 105 enum { AttributeLike = 1 << 0 }; 106 107 static bool classof(const VariableElement *ve) { 108 return ve->getKind() == VariableElement::Attribute || 109 ve->getKind() == VariableElement::Property; 110 } 111 112 static bool classof(const FormatElement *fe) { 113 return isa<VariableElement>(fe) && classof(cast<VariableElement>(fe)); 114 } 115 116 /// Returns true if the variable is a UnitAttr or a UnitProp. 117 bool isUnit() const { 118 if (const auto *attr = dyn_cast<AttributeVariable>(this)) 119 return attr->getVar()->attr.getBaseAttr().getAttrDefName() == "UnitAttr"; 120 if (const auto *prop = dyn_cast<PropertyVariable>(this)) { 121 StringRef baseDefName = 122 prop->getVar()->prop.getBaseProperty().getPropertyDefName(); 123 // Note: remove the `UnitProperty` case once the deprecation period is 124 // over. 125 return baseDefName == "UnitProp" || baseDefName == "UnitProperty"; 126 } 127 llvm_unreachable("Type that wasn't listed in classof()"); 128 } 129 130 StringRef getName() const { 131 if (const auto *attr = dyn_cast<AttributeVariable>(this)) 132 return attr->getVar()->name; 133 if (const auto *prop = dyn_cast<PropertyVariable>(this)) 134 return prop->getVar()->name; 135 llvm_unreachable("Type that wasn't listed in classof()"); 136 } 137 }; 138 } // namespace 139 140 //===----------------------------------------------------------------------===// 141 // DirectiveElement 142 143 namespace { 144 /// This class represents the `operands` directive. This directive represents 145 /// all of the operands of an operation. 146 using OperandsDirective = DirectiveElementBase<DirectiveElement::Operands>; 147 148 /// This class represents the `results` directive. This directive represents 149 /// all of the results of an operation. 150 using ResultsDirective = DirectiveElementBase<DirectiveElement::Results>; 151 152 /// This class represents the `regions` directive. 
This directive represents 153 /// all of the regions of an operation. 154 using RegionsDirective = DirectiveElementBase<DirectiveElement::Regions>; 155 156 /// This class represents the `successors` directive. This directive represents 157 /// all of the successors of an operation. 158 using SuccessorsDirective = DirectiveElementBase<DirectiveElement::Successors>; 159 160 /// This class represents the `attr-dict` directive. This directive represents 161 /// the attribute dictionary of the operation. 162 class AttrDictDirective 163 : public DirectiveElementBase<DirectiveElement::AttrDict> { 164 public: 165 explicit AttrDictDirective(bool withKeyword) : withKeyword(withKeyword) {} 166 167 /// Return whether the dictionary should be printed with the 'attributes' 168 /// keyword. 169 bool isWithKeyword() const { return withKeyword; } 170 171 private: 172 /// If the dictionary should be printed with the 'attributes' keyword. 173 bool withKeyword; 174 }; 175 176 /// This class represents the `prop-dict` directive. This directive represents 177 /// the properties of the operation, expressed as a directionary. 178 class PropDictDirective 179 : public DirectiveElementBase<DirectiveElement::PropDict> { 180 public: 181 explicit PropDictDirective() = default; 182 }; 183 184 /// This class represents the `functional-type` directive. This directive takes 185 /// two arguments and formats them, respectively, as the inputs and results of a 186 /// FunctionType. 187 class FunctionalTypeDirective 188 : public DirectiveElementBase<DirectiveElement::FunctionalType> { 189 public: 190 FunctionalTypeDirective(FormatElement *inputs, FormatElement *results) 191 : inputs(inputs), results(results) {} 192 193 FormatElement *getInputs() const { return inputs; } 194 FormatElement *getResults() const { return results; } 195 196 private: 197 /// The input and result arguments. 198 FormatElement *inputs, *results; 199 }; 200 201 /// This class represents the `type` directive. 
202 class TypeDirective : public DirectiveElementBase<DirectiveElement::Type> { 203 public: 204 TypeDirective(FormatElement *arg) : arg(arg) {} 205 206 FormatElement *getArg() const { return arg; } 207 208 /// Indicate if this type is printed "qualified" (that is it is 209 /// prefixed with the `!dialect.mnemonic`). 210 bool shouldBeQualified() { return shouldBeQualifiedFlag; } 211 void setShouldBeQualified(bool qualified = true) { 212 shouldBeQualifiedFlag = qualified; 213 } 214 215 private: 216 /// The argument that is used to format the directive. 217 FormatElement *arg; 218 219 bool shouldBeQualifiedFlag = false; 220 }; 221 222 /// This class represents a group of order-independent optional clauses. Each 223 /// clause starts with a literal element and has a coressponding parsing 224 /// element. A parsing element is a continous sequence of format elements. 225 /// Each clause can appear 0 or 1 time. 226 class OIListElement : public DirectiveElementBase<DirectiveElement::OIList> { 227 public: 228 OIListElement(std::vector<FormatElement *> &&literalElements, 229 std::vector<std::vector<FormatElement *>> &&parsingElements) 230 : literalElements(std::move(literalElements)), 231 parsingElements(std::move(parsingElements)) {} 232 233 /// Returns a range to iterate over the LiteralElements. 234 auto getLiteralElements() const { 235 return llvm::map_range(literalElements, [](FormatElement *el) { 236 return cast<LiteralElement>(el); 237 }); 238 } 239 240 /// Returns a range to iterate over the parsing elements corresponding to the 241 /// clauses. 242 ArrayRef<std::vector<FormatElement *>> getParsingElements() const { 243 return parsingElements; 244 } 245 246 /// Returns a range to iterate over tuples of parsing and literal elements. 247 auto getClauses() const { 248 return llvm::zip(getLiteralElements(), getParsingElements()); 249 } 250 251 /// If the parsing element is a single UnitAttr element, then it returns the 252 /// attribute variable. 
Otherwise, returns nullptr. 253 AttributeLikeVariable * 254 getUnitVariableParsingElement(ArrayRef<FormatElement *> pelement) { 255 if (pelement.size() == 1) { 256 auto *attrElem = dyn_cast<AttributeLikeVariable>(pelement[0]); 257 if (attrElem && attrElem->isUnit()) 258 return attrElem; 259 } 260 return nullptr; 261 } 262 263 private: 264 /// A vector of `LiteralElement` objects. Each element stores the keyword 265 /// for one case of oilist element. For example, an oilist element along with 266 /// the `literalElements` vector: 267 /// ``` 268 /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`] 269 /// literalElements = { `keyword`, `otherKeyword` } 270 /// ``` 271 std::vector<FormatElement *> literalElements; 272 273 /// A vector of valid declarative assembly format vectors. Each object in 274 /// parsing elements is a vector of elements in assembly format syntax. 275 /// For example, an oilist element along with the parsingElements vector: 276 /// ``` 277 /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`] 278 /// parsingElements = { 279 /// { `=`, `(`, $arg0, `)` }, 280 /// { `<`, $arg1, `>` } 281 /// } 282 /// ``` 283 std::vector<std::vector<FormatElement *>> parsingElements; 284 }; 285 } // namespace 286 287 //===----------------------------------------------------------------------===// 288 // OperationFormat 289 //===----------------------------------------------------------------------===// 290 291 namespace { 292 293 using ConstArgument = 294 llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>; 295 296 struct OperationFormat { 297 /// This class represents a specific resolver for an operand or result type. 298 class TypeResolution { 299 public: 300 TypeResolution() = default; 301 302 /// Get the index into the buildable types for this type, or std::nullopt. 
303 std::optional<int> getBuilderIdx() const { return builderIdx; } 304 void setBuilderIdx(int idx) { builderIdx = idx; } 305 306 /// Get the variable this type is resolved to, or nullptr. 307 const NamedTypeConstraint *getVariable() const { 308 return llvm::dyn_cast_if_present<const NamedTypeConstraint *>(resolver); 309 } 310 /// Get the attribute this type is resolved to, or nullptr. 311 const NamedAttribute *getAttribute() const { 312 return llvm::dyn_cast_if_present<const NamedAttribute *>(resolver); 313 } 314 /// Get the transformer for the type of the variable, or std::nullopt. 315 std::optional<StringRef> getVarTransformer() const { 316 return variableTransformer; 317 } 318 void setResolver(ConstArgument arg, std::optional<StringRef> transformer) { 319 resolver = arg; 320 variableTransformer = transformer; 321 assert(getVariable() || getAttribute()); 322 } 323 324 private: 325 /// If the type is resolved with a buildable type, this is the index into 326 /// 'buildableTypes' in the parent format. 327 std::optional<int> builderIdx; 328 /// If the type is resolved based upon another operand or result, this is 329 /// the variable or the attribute that this type is resolved to. 330 ConstArgument resolver; 331 /// If the type is resolved based upon another operand or result, this is 332 /// a transformer to apply to the variable when resolving. 333 std::optional<StringRef> variableTransformer; 334 }; 335 336 /// The context in which an element is generated. 337 enum class GenContext { 338 /// The element is generated at the top-level or with the same behaviour. 339 Normal, 340 /// The element is generated inside an optional group. 
341 Optional 342 }; 343 344 OperationFormat(const Operator &op, bool hasProperties) 345 : useProperties(hasProperties), opCppClassName(op.getCppClassName()) { 346 operandTypes.resize(op.getNumOperands(), TypeResolution()); 347 resultTypes.resize(op.getNumResults(), TypeResolution()); 348 349 hasImplicitTermTrait = llvm::any_of(op.getTraits(), [](const Trait &trait) { 350 return trait.getDef().isSubClassOf("SingleBlockImplicitTerminatorImpl"); 351 }); 352 353 hasSingleBlockTrait = op.getTrait("::mlir::OpTrait::SingleBlock"); 354 } 355 356 /// Generate the operation parser from this format. 357 void genParser(Operator &op, OpClass &opClass); 358 /// Generate the parser code for a specific format element. 359 void genElementParser(FormatElement *element, MethodBody &body, 360 FmtContext &attrTypeCtx, 361 GenContext genCtx = GenContext::Normal); 362 /// Generate the C++ to resolve the types of operands and results during 363 /// parsing. 364 void genParserTypeResolution(Operator &op, MethodBody &body); 365 /// Generate the C++ to resolve the types of the operands during parsing. 366 void genParserOperandTypeResolution( 367 Operator &op, MethodBody &body, 368 function_ref<void(TypeResolution &, StringRef)> emitTypeResolver); 369 /// Generate the C++ to resolve regions during parsing. 370 void genParserRegionResolution(Operator &op, MethodBody &body); 371 /// Generate the C++ to resolve successors during parsing. 372 void genParserSuccessorResolution(Operator &op, MethodBody &body); 373 /// Generate the C++ to handling variadic segment size traits. 374 void genParserVariadicSegmentResolution(Operator &op, MethodBody &body); 375 376 /// Generate the operation printer from this format. 377 void genPrinter(Operator &op, OpClass &opClass); 378 379 /// Generate the printer code for a specific format element. 
380 void genElementPrinter(FormatElement *element, MethodBody &body, Operator &op, 381 bool &shouldEmitSpace, bool &lastWasPunctuation); 382 383 /// The various elements in this format. 384 std::vector<FormatElement *> elements; 385 386 /// A flag indicating if all operand/result types were seen. If the format 387 /// contains these, it can not contain individual type resolvers. 388 bool allOperands = false, allOperandTypes = false, allResultTypes = false; 389 390 /// A flag indicating if this operation infers its result types 391 bool infersResultTypes = false; 392 393 /// A flag indicating if this operation has the SingleBlockImplicitTerminator 394 /// trait. 395 bool hasImplicitTermTrait; 396 397 /// A flag indicating if this operation has the SingleBlock trait. 398 bool hasSingleBlockTrait; 399 400 /// Indicate whether we need to use properties for the current operator. 401 bool useProperties; 402 403 /// Indicate whether prop-dict is used in the format 404 bool hasPropDict; 405 406 /// The Operation class name 407 StringRef opCppClassName; 408 409 /// A map of buildable types to indices. 410 llvm::MapVector<StringRef, int, StringMap<int>> buildableTypes; 411 412 /// The index of the buildable type, if valid, for every operand and result. 413 std::vector<TypeResolution> operandTypes, resultTypes; 414 415 /// The set of attributes explicitly used within the format. 416 llvm::SmallSetVector<const NamedAttribute *, 8> usedAttributes; 417 llvm::StringSet<> inferredAttributes; 418 419 /// The set of properties explicitly used within the format. 420 llvm::SmallSetVector<const NamedProperty *, 8> usedProperties; 421 }; 422 } // namespace 423 424 //===----------------------------------------------------------------------===// 425 // Parser Gen 426 427 /// Returns true if we can format the given attribute as an EnumAttr in the 428 /// parser format. 
429 static bool canFormatEnumAttr(const NamedAttribute *attr) { 430 Attribute baseAttr = attr->attr.getBaseAttr(); 431 const EnumAttr *enumAttr = dyn_cast<EnumAttr>(&baseAttr); 432 if (!enumAttr) 433 return false; 434 435 // The attribute must have a valid underlying type and a constant builder. 436 return !enumAttr->getUnderlyingType().empty() && 437 !enumAttr->getConstBuilderTemplate().empty(); 438 } 439 440 /// Returns if we should format the given attribute as an SymbolNameAttr. 441 static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) { 442 return attr->attr.getBaseAttr().getAttrDefName() == "SymbolNameAttr"; 443 } 444 445 /// The code snippet used to generate a parser call for an attribute. 446 /// 447 /// {0}: The name of the attribute. 448 /// {1}: The type for the attribute. 449 const char *const attrParserCode = R"( 450 if (parser.parseCustomAttributeWithFallback({0}Attr, {1})) {{ 451 return ::mlir::failure(); 452 } 453 )"; 454 455 /// The code snippet used to generate a parser call for an attribute. 456 /// 457 /// {0}: The name of the attribute. 458 /// {1}: The type for the attribute. 459 const char *const genericAttrParserCode = R"( 460 if (parser.parseAttribute({0}Attr, {1})) 461 return ::mlir::failure(); 462 )"; 463 464 const char *const optionalAttrParserCode = R"( 465 ::mlir::OptionalParseResult parseResult{0}Attr = 466 parser.parseOptionalAttribute({0}Attr, {1}); 467 if (parseResult{0}Attr.has_value() && failed(*parseResult{0}Attr)) 468 return ::mlir::failure(); 469 if (parseResult{0}Attr.has_value() && succeeded(*parseResult{0}Attr)) 470 )"; 471 472 /// The code snippet used to generate a parser call for a symbol name attribute. 473 /// 474 /// {0}: The name of the attribute. 
const char *const symbolNameAttrParserCode = R"(
  if (parser.parseSymbolName({0}Attr))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for an optional symbol
/// name attribute.
///
/// {0}: The name of the attribute.
const char *const optionalSymbolNameAttrParserCode = R"(
  // Parsing an optional symbol name doesn't fail, so no need to check the
  // result.
  (void)parser.parseOptionalSymbolName({0}Attr);
)";

/// The code snippet used to generate a parser call for an enum attribute.
///
/// {0}: The name of the attribute.
/// {1}: The c++ namespace for the enum symbolize functions.
/// {2}: The function to symbolize a string of the enum.
/// {3}: The constant builder call to create an attribute of the enum type.
/// {4}: The set of allowed enum keywords.
/// {5}: The error message on failure when the enum isn't present.
/// {6}: The attribute assignment expression
///
/// Note: only the `{` that would otherwise pair up with a following `}` is
/// escaped as `{{`; keep that in mind when editing braces in this snippet.
const char *const enumAttrParserCode = R"(
  {
    ::llvm::StringRef attrStr;
    ::mlir::NamedAttrList attrStorage;
    auto loc = parser.getCurrentLocation();
    if (parser.parseOptionalKeyword(&attrStr, {4})) {
      ::mlir::StringAttr attrVal;
      ::mlir::OptionalParseResult parseResult =
        parser.parseOptionalAttribute(attrVal,
                                      parser.getBuilder().getNoneType(),
                                      "{0}", attrStorage);
      if (parseResult.has_value()) {{
        if (failed(*parseResult))
          return ::mlir::failure();
        attrStr = attrVal.getValue();
      } else {
        {5}
      }
    }
    if (!attrStr.empty()) {
      auto attrOptional = {1}::{2}(attrStr);
      if (!attrOptional)
        return parser.emitError(loc, "invalid ")
               << "{0} attribute specification: \"" << attrStr << '"';

      {0}Attr = {3};
      {6}
    }
  }
)";

/// The code snippet used to generate a parser call for a property.
/// {0}: The name of the property
/// {1}: The C++ class name of the operation
/// {2}: The property's parser code with appropriate substitutions performed
/// {3}: The description of the expected property for the error message.
const char *const propertyParserCode = R"(
  auto {0}PropLoc = parser.getCurrentLocation();
  auto {0}PropParseResult = [&](auto& propStorage) -> ::mlir::ParseResult {{
    {2}
    return ::mlir::success();
  }(result.getOrAddProperties<{1}::Properties>().{0});
  if (failed({0}PropParseResult)) {{
    return parser.emitError({0}PropLoc, "invalid value for property {0}, expected {3}");
  }
)";

/// The code snippet used to generate a parser call for an optional property.
/// The generated lambda returns an `::mlir::OptionalParseResult`, and the
/// snippet only bails out when a value is present and it is a failure.
/// {0}: The name of the property
/// {1}: The C++ class name of the operation
/// {2}: The property's parser code with appropriate substitutions performed
const char *const optionalPropertyParserCode = R"(
  auto {0}PropParseResult = [&](auto& propStorage) -> ::mlir::OptionalParseResult {{
    {2}
    return ::mlir::success();
  }(result.getOrAddProperties<{1}::Properties>().{0});
  if ({0}PropParseResult.has_value() && failed(*{0}PropParseResult)) {{
    return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for a variadic operand.
///
/// {0}: The name of the operand.
const char *const variadicOperandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperandList({0}Operands))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for an optional operand.
/// The operand is appended to `{0}Operands` only when one was parsed.
///
/// {0}: The name of the operand.
const char *const optionalOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    ::mlir::OpAsmParser::UnresolvedOperand operand;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalOperand(operand);
    if (parseResult.has_value()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Operands.push_back(operand);
    }
  }
)";
/// The code snippet used to generate a parser call for a single operand,
/// parsed into `{0}RawOperand`.
///
/// {0}: The name of the operand.
const char *const operandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperand({0}RawOperand))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for a VariadicOfVariadic
/// operand.
///
/// {0}: The name of the operand.
/// {1}: The name of segment size attribute.
///      NOTE(review): `{1}` is not referenced by the snippet below as written.
const char *const variadicOfVariadicOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    int32_t curSize = 0;
    do {
      if (parser.parseOptionalLParen())
        break;
      if (parser.parseOperandList({0}Operands) || parser.parseRParen())
        return ::mlir::failure();
      {0}OperandGroupSizes.push_back({0}Operands.size() - curSize);
      curSize = {0}Operands.size();
    } while (succeeded(parser.parseOptionalComma()));
  }
)";

/// The code snippet used to generate a parser call for a type list.
///
/// {0}: The name for the type list.
const char *const variadicOfVariadicTypeParserCode = R"(
  do {
    if (parser.parseOptionalLParen())
      break;
    if (parser.parseOptionalRParen() &&
        (parser.parseTypeList({0}Types) || parser.parseRParen()))
      return ::mlir::failure();
  } while (succeeded(parser.parseOptionalComma()));
)";
/// The code snippet used to generate a parser call for a variadic type list.
///
/// {0}: The name for the type list.
const char *const variadicTypeParserCode = R"(
  if (parser.parseTypeList({0}Types))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for an optional type. The
/// type is appended to `{0}Types` only when one was parsed.
///
/// {0}: The name for the type list.
const char *const optionalTypeParserCode = R"(
  {
    ::mlir::Type optionalType;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalType(optionalType);
    if (parseResult.has_value()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Types.push_back(optionalType);
    }
  }
)";
/// The code snippet used to generate a parser call for a single type, going
/// through `parseCustomTypeWithFallback`.
///
/// {0}: The C++ type of the local the type is parsed into.
/// {1}: The name prefix of the `RawType` variable receiving the result.
const char *const typeParserCode = R"(
  {
    {0} type;
    if (parser.parseCustomTypeWithFallback(type))
      return ::mlir::failure();
    {1}RawType = type;
  }
)";
/// The code snippet used to generate a parser call for a single type, going
/// through the generic `parseType`. Only `{1}` is referenced, which keeps the
/// placeholder numbering aligned with `typeParserCode`.
///
/// {1}: The name prefix of the `RawType` variable receiving the result.
const char *const qualifiedTypeParserCode = R"(
  if (parser.parseType({1}RawType))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for a functional type.
///
/// {0}: The name for the input type list.
/// {1}: The name for the result type list.
const char *const functionalTypeParserCode = R"(
  ::mlir::FunctionType {0}__{1}_functionType;
  if (parser.parseType({0}__{1}_functionType))
    return ::mlir::failure();
  {0}Types = {0}__{1}_functionType.getInputs();
  {1}Types = {0}__{1}_functionType.getResults();
)";

/// The code snippet used to generate a parser call to infer return types.
///
/// {0}: The operation class name
const char *const inferReturnTypesParserCode = R"(
  ::llvm::SmallVector<::mlir::Type> inferredReturnTypes;
  if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
      result.location, result.operands,
      result.attributes.getDictionary(parser.getContext()),
      result.getRawProperties(),
      result.regions, inferredReturnTypes)))
    return ::mlir::failure();
  result.addTypes(inferredReturnTypes);
)";

/// The code snippet used to generate a parser call for a region list. The
/// first region is optional; every region after a comma is required.
///
/// {0}: The name for the region list.
const char *regionListParserCode = R"(
  {
    std::unique_ptr<::mlir::Region> region;
    auto firstRegionResult = parser.parseOptionalRegion(region);
    if (firstRegionResult.has_value()) {
      if (failed(*firstRegionResult))
        return ::mlir::failure();
      {0}Regions.emplace_back(std::move(region));

      // Parse any trailing regions.
      while (succeeded(parser.parseOptionalComma())) {
        region = std::make_unique<::mlir::Region>();
        if (parser.parseRegion(*region))
          return ::mlir::failure();
        {0}Regions.emplace_back(std::move(region));
      }
    }
  }
)";

/// The code snippet used to ensure a list of regions have terminators.
///
/// {0}: The name of the region list.
const char *regionListEnsureTerminatorParserCode = R"(
  for (auto &region : {0}Regions)
    ensureTerminator(*region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a list of regions have a block.
///
/// {0}: The name of the region list.
const char *regionListEnsureSingleBlockParserCode = R"(
  for (auto &region : {0}Regions)
    if (region->empty()) region->emplaceBlock();
)";

/// The code snippet used to generate a parser call for an optional region.
///
/// {0}: The name of the region.
const char *optionalRegionParserCode = R"(
  {
     auto parseResult = parser.parseOptionalRegion(*{0}Region);
     if (parseResult.has_value() && failed(*parseResult))
       return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for a region.
///
/// {0}: The name of the region.
const char *regionParserCode = R"(
  if (parser.parseRegion(*{0}Region))
    return ::mlir::failure();
)";

/// The code snippet used to ensure a region has a terminator.
///
/// {0}: The name of the region.
const char *regionEnsureTerminatorParserCode = R"(
  ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a region has a block.
///
/// {0}: The name of the region.
const char *regionEnsureSingleBlockParserCode = R"(
  if ({0}Region->empty()) {0}Region->emplaceBlock();
)";

/// The code snippet used to generate a parser call for a successor list. The
/// first successor is optional; every successor after a comma is required.
///
/// {0}: The name for the successor list.
const char *successorListParserCode = R"(
  {
    ::mlir::Block *succ;
    auto firstSucc = parser.parseOptionalSuccessor(succ);
    if (firstSucc.has_value()) {
      if (failed(*firstSucc))
        return ::mlir::failure();
      {0}Successors.emplace_back(succ);

      // Parse any trailing successors.
      while (succeeded(parser.parseOptionalComma())) {
        if (parser.parseSuccessor(succ))
          return ::mlir::failure();
        {0}Successors.emplace_back(succ);
      }
    }
  }
)";

/// The code snippet used to generate a parser call for a successor.
///
/// {0}: The name of the successor.
765 const char *successorParserCode = R"( 766 if (parser.parseSuccessor({0}Successor)) 767 return ::mlir::failure(); 768 )"; 769 770 /// The code snippet used to generate a parser for OIList 771 /// 772 /// {0}: literal keyword corresponding to a case for oilist 773 const char *oilistParserCode = R"( 774 if ({0}Clause) { 775 return parser.emitError(parser.getNameLoc()) 776 << "`{0}` clause can appear at most once in the expansion of the " 777 "oilist directive"; 778 } 779 {0}Clause = true; 780 )"; 781 782 namespace { 783 /// The type of length for a given parse argument. 784 enum class ArgumentLengthKind { 785 /// The argument is a variadic of a variadic, and may contain 0->N range 786 /// elements. 787 VariadicOfVariadic, 788 /// The argument is variadic, and may contain 0->N elements. 789 Variadic, 790 /// The argument is optional, and may contain 0 or 1 elements. 791 Optional, 792 /// The argument is a single element, i.e. always represents 1 element. 793 Single 794 }; 795 } // namespace 796 797 /// Get the length kind for the given constraint. 798 static ArgumentLengthKind 799 getArgumentLengthKind(const NamedTypeConstraint *var) { 800 if (var->isOptional()) 801 return ArgumentLengthKind::Optional; 802 if (var->isVariadicOfVariadic()) 803 return ArgumentLengthKind::VariadicOfVariadic; 804 if (var->isVariadic()) 805 return ArgumentLengthKind::Variadic; 806 return ArgumentLengthKind::Single; 807 } 808 809 /// Get the name used for the type list for the given type directive operand. 810 /// 'lengthKind' to the corresponding kind for the given argument. 
811 static StringRef getTypeListName(FormatElement *arg, 812 ArgumentLengthKind &lengthKind) { 813 if (auto *operand = dyn_cast<OperandVariable>(arg)) { 814 lengthKind = getArgumentLengthKind(operand->getVar()); 815 return operand->getVar()->name; 816 } 817 if (auto *result = dyn_cast<ResultVariable>(arg)) { 818 lengthKind = getArgumentLengthKind(result->getVar()); 819 return result->getVar()->name; 820 } 821 lengthKind = ArgumentLengthKind::Variadic; 822 if (isa<OperandsDirective>(arg)) 823 return "allOperand"; 824 if (isa<ResultsDirective>(arg)) 825 return "allResult"; 826 llvm_unreachable("unknown 'type' directive argument"); 827 } 828 829 /// Generate the parser for a literal value. 830 static void genLiteralParser(StringRef value, MethodBody &body) { 831 // Handle the case of a keyword/identifier. 832 if (value.front() == '_' || isalpha(value.front())) { 833 body << "Keyword(\"" << value << "\")"; 834 return; 835 } 836 body << (StringRef)StringSwitch<StringRef>(value) 837 .Case("->", "Arrow()") 838 .Case(":", "Colon()") 839 .Case(",", "Comma()") 840 .Case("=", "Equal()") 841 .Case("<", "Less()") 842 .Case(">", "Greater()") 843 .Case("{", "LBrace()") 844 .Case("}", "RBrace()") 845 .Case("(", "LParen()") 846 .Case(")", "RParen()") 847 .Case("[", "LSquare()") 848 .Case("]", "RSquare()") 849 .Case("?", "Question()") 850 .Case("+", "Plus()") 851 .Case("*", "Star()") 852 .Case("...", "Ellipsis()"); 853 } 854 855 /// Generate the storage code required for parsing the given element. 856 static void genElementParserStorage(FormatElement *element, const Operator &op, 857 MethodBody &body) { 858 if (auto *optional = dyn_cast<OptionalElement>(element)) { 859 ArrayRef<FormatElement *> elements = optional->getThenElements(); 860 861 // If the anchor is a unit attribute, it won't be parsed directly so elide 862 // it. 
863 auto *anchor = dyn_cast<AttributeLikeVariable>(optional->getAnchor()); 864 FormatElement *elidedAnchorElement = nullptr; 865 if (anchor && anchor != elements.front() && anchor->isUnit()) 866 elidedAnchorElement = anchor; 867 for (FormatElement *childElement : elements) 868 if (childElement != elidedAnchorElement) 869 genElementParserStorage(childElement, op, body); 870 for (FormatElement *childElement : optional->getElseElements()) 871 genElementParserStorage(childElement, op, body); 872 873 } else if (auto *oilist = dyn_cast<OIListElement>(element)) { 874 for (ArrayRef<FormatElement *> pelement : oilist->getParsingElements()) { 875 if (!oilist->getUnitVariableParsingElement(pelement)) 876 for (FormatElement *element : pelement) 877 genElementParserStorage(element, op, body); 878 } 879 880 } else if (auto *custom = dyn_cast<CustomDirective>(element)) { 881 for (FormatElement *paramElement : custom->getArguments()) 882 genElementParserStorage(paramElement, op, body); 883 884 } else if (isa<OperandsDirective>(element)) { 885 body << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> " 886 "allOperands;\n"; 887 888 } else if (isa<RegionsDirective>(element)) { 889 body << " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> " 890 "fullRegions;\n"; 891 892 } else if (isa<SuccessorsDirective>(element)) { 893 body << " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n"; 894 895 } else if (auto *attr = dyn_cast<AttributeVariable>(element)) { 896 const NamedAttribute *var = attr->getVar(); 897 body << formatv(" {0} {1}Attr;\n", var->attr.getStorageType(), var->name); 898 899 } else if (auto *operand = dyn_cast<OperandVariable>(element)) { 900 StringRef name = operand->getVar()->name; 901 if (operand->getVar()->isVariableLength()) { 902 body 903 << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> " 904 << name << "Operands;\n"; 905 if (operand->getVar()->isVariadicOfVariadic()) { 906 body << " llvm::SmallVector<int32_t> " << 
name 907 << "OperandGroupSizes;\n"; 908 } 909 } else { 910 body << " ::mlir::OpAsmParser::UnresolvedOperand " << name 911 << "RawOperand{};\n" 912 << " ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> " 913 << name << "Operands(&" << name << "RawOperand, 1);"; 914 } 915 body << formatv(" ::llvm::SMLoc {0}OperandsLoc;\n" 916 " (void){0}OperandsLoc;\n", 917 name); 918 919 } else if (auto *region = dyn_cast<RegionVariable>(element)) { 920 StringRef name = region->getVar()->name; 921 if (region->getVar()->isVariadic()) { 922 body << formatv( 923 " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> " 924 "{0}Regions;\n", 925 name); 926 } else { 927 body << formatv(" std::unique_ptr<::mlir::Region> {0}Region = " 928 "std::make_unique<::mlir::Region>();\n", 929 name); 930 } 931 932 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) { 933 StringRef name = successor->getVar()->name; 934 if (successor->getVar()->isVariadic()) { 935 body << formatv(" ::llvm::SmallVector<::mlir::Block *, 2> " 936 "{0}Successors;\n", 937 name); 938 } else { 939 body << formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name); 940 } 941 942 } else if (auto *dir = dyn_cast<TypeDirective>(element)) { 943 ArgumentLengthKind lengthKind; 944 StringRef name = getTypeListName(dir->getArg(), lengthKind); 945 if (lengthKind != ArgumentLengthKind::Single) 946 body << " ::llvm::SmallVector<::mlir::Type, 1> " << name << "Types;\n"; 947 else 948 body 949 << formatv(" ::mlir::Type {0}RawType{{};\n", name) 950 << formatv( 951 " ::llvm::ArrayRef<::mlir::Type> {0}Types(&{0}RawType, 1);\n", 952 name); 953 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) { 954 ArgumentLengthKind ignored; 955 body << " ::llvm::ArrayRef<::mlir::Type> " 956 << getTypeListName(dir->getInputs(), ignored) << "Types;\n"; 957 body << " ::llvm::ArrayRef<::mlir::Type> " 958 << getTypeListName(dir->getResults(), ignored) << "Types;\n"; 959 } 960 } 961 962 /// Generate the parser for a 
parameter to a custom directive. 963 static void genCustomParameterParser(FormatElement *param, MethodBody &body) { 964 if (auto *attr = dyn_cast<AttributeVariable>(param)) { 965 body << attr->getVar()->name << "Attr"; 966 } else if (isa<AttrDictDirective>(param)) { 967 body << "result.attributes"; 968 } else if (isa<PropDictDirective>(param)) { 969 body << "result"; 970 } else if (auto *operand = dyn_cast<OperandVariable>(param)) { 971 StringRef name = operand->getVar()->name; 972 ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); 973 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) 974 body << formatv("{0}OperandGroups", name); 975 else if (lengthKind == ArgumentLengthKind::Variadic) 976 body << formatv("{0}Operands", name); 977 else if (lengthKind == ArgumentLengthKind::Optional) 978 body << formatv("{0}Operand", name); 979 else 980 body << formatv("{0}RawOperand", name); 981 982 } else if (auto *region = dyn_cast<RegionVariable>(param)) { 983 StringRef name = region->getVar()->name; 984 if (region->getVar()->isVariadic()) 985 body << formatv("{0}Regions", name); 986 else 987 body << formatv("*{0}Region", name); 988 989 } else if (auto *successor = dyn_cast<SuccessorVariable>(param)) { 990 StringRef name = successor->getVar()->name; 991 if (successor->getVar()->isVariadic()) 992 body << formatv("{0}Successors", name); 993 else 994 body << formatv("{0}Successor", name); 995 996 } else if (auto *dir = dyn_cast<RefDirective>(param)) { 997 genCustomParameterParser(dir->getArg(), body); 998 999 } else if (auto *dir = dyn_cast<TypeDirective>(param)) { 1000 ArgumentLengthKind lengthKind; 1001 StringRef listName = getTypeListName(dir->getArg(), lengthKind); 1002 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) 1003 body << formatv("{0}TypeGroups", listName); 1004 else if (lengthKind == ArgumentLengthKind::Variadic) 1005 body << formatv("{0}Types", listName); 1006 else if (lengthKind == ArgumentLengthKind::Optional) 1007 body << 
formatv("{0}Type", listName); 1008 else 1009 body << formatv("{0}RawType", listName); 1010 1011 } else if (auto *string = dyn_cast<StringElement>(param)) { 1012 FmtContext ctx; 1013 ctx.withBuilder("parser.getBuilder()"); 1014 ctx.addSubst("_ctxt", "parser.getContext()"); 1015 body << tgfmt(string->getValue(), &ctx); 1016 1017 } else if (auto *property = dyn_cast<PropertyVariable>(param)) { 1018 body << formatv("result.getOrAddProperties<Properties>().{0}", 1019 property->getVar()->name); 1020 } else { 1021 llvm_unreachable("unknown custom directive parameter"); 1022 } 1023 } 1024 1025 /// Generate the parser for a custom directive. 1026 static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body, 1027 bool useProperties, 1028 StringRef opCppClassName, 1029 bool isOptional = false) { 1030 body << " {\n"; 1031 1032 // Preprocess the directive variables. 1033 // * Add a local variable for optional operands and types. This provides a 1034 // better API to the user defined parser methods. 1035 // * Set the location of operand variables. 
1036 for (FormatElement *param : dir->getArguments()) { 1037 if (auto *operand = dyn_cast<OperandVariable>(param)) { 1038 auto *var = operand->getVar(); 1039 body << " " << var->name 1040 << "OperandsLoc = parser.getCurrentLocation();\n"; 1041 if (var->isOptional()) { 1042 body << formatv( 1043 " ::std::optional<::mlir::OpAsmParser::UnresolvedOperand> " 1044 "{0}Operand;\n", 1045 var->name); 1046 } else if (var->isVariadicOfVariadic()) { 1047 body << formatv(" " 1048 "::llvm::SmallVector<::llvm::SmallVector<::mlir::" 1049 "OpAsmParser::UnresolvedOperand>> " 1050 "{0}OperandGroups;\n", 1051 var->name); 1052 } 1053 } else if (auto *dir = dyn_cast<TypeDirective>(param)) { 1054 ArgumentLengthKind lengthKind; 1055 StringRef listName = getTypeListName(dir->getArg(), lengthKind); 1056 if (lengthKind == ArgumentLengthKind::Optional) { 1057 body << formatv(" ::mlir::Type {0}Type;\n", listName); 1058 } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { 1059 body << formatv( 1060 " ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> " 1061 "{0}TypeGroups;\n", 1062 listName); 1063 } 1064 } else if (auto *dir = dyn_cast<RefDirective>(param)) { 1065 FormatElement *input = dir->getArg(); 1066 if (auto *operand = dyn_cast<OperandVariable>(input)) { 1067 if (!operand->getVar()->isOptional()) 1068 continue; 1069 body << formatv( 1070 " {0} {1}Operand = {1}Operands.empty() ? {0}() : " 1071 "{1}Operands[0];\n", 1072 "::std::optional<::mlir::OpAsmParser::UnresolvedOperand>", 1073 operand->getVar()->name); 1074 1075 } else if (auto *type = dyn_cast<TypeDirective>(input)) { 1076 ArgumentLengthKind lengthKind; 1077 StringRef listName = getTypeListName(type->getArg(), lengthKind); 1078 if (lengthKind == ArgumentLengthKind::Optional) { 1079 body << formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? 
" 1080 "::mlir::Type() : {0}Types[0];\n", 1081 listName); 1082 } 1083 } 1084 } 1085 } 1086 1087 body << " auto odsResult = parse" << dir->getName() << "(parser"; 1088 for (FormatElement *param : dir->getArguments()) { 1089 body << ", "; 1090 genCustomParameterParser(param, body); 1091 } 1092 body << ");\n"; 1093 1094 if (isOptional) { 1095 body << " if (!odsResult.has_value()) return {};\n" 1096 << " if (::mlir::failed(*odsResult)) return ::mlir::failure();\n"; 1097 } else { 1098 body << " if (odsResult) return ::mlir::failure();\n"; 1099 } 1100 1101 // After parsing, add handling for any of the optional constructs. 1102 for (FormatElement *param : dir->getArguments()) { 1103 if (auto *attr = dyn_cast<AttributeVariable>(param)) { 1104 const NamedAttribute *var = attr->getVar(); 1105 if (var->attr.isOptional() || var->attr.hasDefaultValue()) 1106 body << formatv(" if ({0}Attr)\n ", var->name); 1107 if (useProperties) { 1108 body << formatv( 1109 " result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n", 1110 var->name, opCppClassName); 1111 } else { 1112 body << formatv(" result.addAttribute(\"{0}\", {0}Attr);\n", 1113 var->name); 1114 } 1115 } else if (auto *operand = dyn_cast<OperandVariable>(param)) { 1116 const NamedTypeConstraint *var = operand->getVar(); 1117 if (var->isOptional()) { 1118 body << formatv(" if ({0}Operand.has_value())\n" 1119 " {0}Operands.push_back(*{0}Operand);\n", 1120 var->name); 1121 } else if (var->isVariadicOfVariadic()) { 1122 body << formatv( 1123 " for (const auto &subRange : {0}OperandGroups) {{\n" 1124 " {0}Operands.append(subRange.begin(), subRange.end());\n" 1125 " {0}OperandGroupSizes.push_back(subRange.size());\n" 1126 " }\n", 1127 var->name); 1128 } 1129 } else if (auto *dir = dyn_cast<TypeDirective>(param)) { 1130 ArgumentLengthKind lengthKind; 1131 StringRef listName = getTypeListName(dir->getArg(), lengthKind); 1132 if (lengthKind == ArgumentLengthKind::Optional) { 1133 body << formatv(" if ({0}Type)\n" 1134 " 
{0}Types.push_back({0}Type);\n", 1135 listName); 1136 } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { 1137 body << formatv( 1138 " for (const auto &subRange : {0}TypeGroups)\n" 1139 " {0}Types.append(subRange.begin(), subRange.end());\n", 1140 listName); 1141 } 1142 } 1143 } 1144 1145 body << " }\n"; 1146 } 1147 1148 /// Generate the parser for a enum attribute. 1149 static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body, 1150 FmtContext &attrTypeCtx, bool parseAsOptional, 1151 bool useProperties, StringRef opCppClassName) { 1152 Attribute baseAttr = var->attr.getBaseAttr(); 1153 const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr); 1154 std::vector<EnumAttrCase> cases = enumAttr.getAllCases(); 1155 1156 // Generate the code for building an attribute for this enum. 1157 std::string attrBuilderStr; 1158 { 1159 llvm::raw_string_ostream os(attrBuilderStr); 1160 os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx, 1161 "*attrOptional"); 1162 } 1163 1164 // Build a string containing the cases that can be formatted as a keyword. 1165 std::string validCaseKeywordsStr = "{"; 1166 llvm::raw_string_ostream validCaseKeywordsOS(validCaseKeywordsStr); 1167 for (const EnumAttrCase &attrCase : cases) 1168 if (canFormatStringAsKeyword(attrCase.getStr())) 1169 validCaseKeywordsOS << '"' << attrCase.getStr() << "\","; 1170 validCaseKeywordsOS.str().back() = '}'; 1171 1172 // If the attribute is not optional, build an error message for the missing 1173 // attribute. 
1174 std::string errorMessage; 1175 if (!parseAsOptional) { 1176 llvm::raw_string_ostream errorMessageOS(errorMessage); 1177 errorMessageOS 1178 << "return parser.emitError(loc, \"expected string or " 1179 "keyword containing one of the following enum values for attribute '" 1180 << var->name << "' ["; 1181 llvm::interleaveComma(cases, errorMessageOS, [&](const auto &attrCase) { 1182 errorMessageOS << attrCase.getStr(); 1183 }); 1184 errorMessageOS << "]\");"; 1185 } 1186 std::string attrAssignment; 1187 if (useProperties) { 1188 attrAssignment = 1189 formatv(" " 1190 "result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;", 1191 var->name, opCppClassName); 1192 } else { 1193 attrAssignment = 1194 formatv("result.addAttribute(\"{0}\", {0}Attr);", var->name); 1195 } 1196 1197 body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(), 1198 enumAttr.getStringToSymbolFnName(), attrBuilderStr, 1199 validCaseKeywordsStr, errorMessage, attrAssignment); 1200 } 1201 1202 // Generate the parser for a property. 1203 static void genPropertyParser(PropertyVariable *propVar, MethodBody &body, 1204 StringRef opCppClassName, 1205 bool requireParse = true) { 1206 StringRef name = propVar->getVar()->name; 1207 const Property &prop = propVar->getVar()->prop; 1208 bool parseOptionally = 1209 prop.hasDefaultValue() && !requireParse && prop.hasOptionalParser(); 1210 FmtContext fmtContext; 1211 fmtContext.addSubst("_parser", "parser"); 1212 fmtContext.addSubst("_ctxt", "parser.getContext()"); 1213 fmtContext.addSubst("_storage", "propStorage"); 1214 1215 if (parseOptionally) { 1216 body << formatv(optionalPropertyParserCode, name, opCppClassName, 1217 tgfmt(prop.getOptionalParserCall(), &fmtContext)); 1218 } else { 1219 body << formatv(propertyParserCode, name, opCppClassName, 1220 tgfmt(prop.getParserCall(), &fmtContext), 1221 prop.getSummary()); 1222 } 1223 } 1224 1225 // Generate the parser for an attribute. 
1226 static void genAttrParser(AttributeVariable *attr, MethodBody &body, 1227 FmtContext &attrTypeCtx, bool parseAsOptional, 1228 bool useProperties, StringRef opCppClassName) { 1229 const NamedAttribute *var = attr->getVar(); 1230 1231 // Check to see if we can parse this as an enum attribute. 1232 if (canFormatEnumAttr(var)) 1233 return genEnumAttrParser(var, body, attrTypeCtx, parseAsOptional, 1234 useProperties, opCppClassName); 1235 1236 // Check to see if we should parse this as a symbol name attribute. 1237 if (shouldFormatSymbolNameAttr(var)) { 1238 body << formatv(parseAsOptional ? optionalSymbolNameAttrParserCode 1239 : symbolNameAttrParserCode, 1240 var->name); 1241 } else { 1242 1243 // If this attribute has a buildable type, use that when parsing the 1244 // attribute. 1245 std::string attrTypeStr; 1246 if (std::optional<StringRef> typeBuilder = attr->getTypeBuilder()) { 1247 llvm::raw_string_ostream os(attrTypeStr); 1248 os << tgfmt(*typeBuilder, &attrTypeCtx); 1249 } else { 1250 attrTypeStr = "::mlir::Type{}"; 1251 } 1252 if (parseAsOptional) { 1253 body << formatv(optionalAttrParserCode, var->name, attrTypeStr); 1254 } else { 1255 if (attr->shouldBeQualified() || 1256 var->attr.getStorageType() == "::mlir::Attribute") 1257 body << formatv(genericAttrParserCode, var->name, attrTypeStr); 1258 else 1259 body << formatv(attrParserCode, var->name, attrTypeStr); 1260 } 1261 } 1262 if (useProperties) { 1263 body << formatv( 1264 " if ({0}Attr) result.getOrAddProperties<{1}::Properties>().{0} = " 1265 "{0}Attr;\n", 1266 var->name, opCppClassName); 1267 } else { 1268 body << formatv( 1269 " if ({0}Attr) result.attributes.append(\"{0}\", {0}Attr);\n", 1270 var->name); 1271 } 1272 } 1273 1274 // Generates the 'setPropertiesFromParsedAttr' used to set properties from a 1275 // 'prop-dict' dictionary attr. 
1276 static void genParsedAttrPropertiesSetter(OperationFormat &fmt, Operator &op, 1277 OpClass &opClass) { 1278 // Not required unless 'prop-dict' is present or we are not using properties. 1279 if (!fmt.hasPropDict || !fmt.useProperties) 1280 return; 1281 1282 SmallVector<MethodParameter> paramList; 1283 paramList.emplace_back("Properties &", "prop"); 1284 paramList.emplace_back("::mlir::Attribute", "attr"); 1285 paramList.emplace_back("::llvm::function_ref<::mlir::InFlightDiagnostic()>", 1286 "emitError"); 1287 1288 Method *method = opClass.addStaticMethod("::llvm::LogicalResult", 1289 "setPropertiesFromParsedAttr", 1290 std::move(paramList)); 1291 MethodBody &body = method->body().indent(); 1292 1293 body << R"decl( 1294 ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); 1295 if (!dict) { 1296 emitError() << "expected DictionaryAttr to set properties"; 1297 return ::mlir::failure(); 1298 } 1299 )decl"; 1300 1301 // {0}: fromAttribute call 1302 // {1}: property name 1303 // {2}: isRequired 1304 const char *propFromAttrFmt = R"decl( 1305 auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr, 1306 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) -> ::mlir::LogicalResult {{ 1307 {0}; 1308 }; 1309 auto attr = dict.get("{1}"); 1310 if (!attr && {2}) {{ 1311 emitError() << "expected key entry for {1} in DictionaryAttr to set " 1312 "Properties."; 1313 return ::mlir::failure(); 1314 } 1315 if (attr && ::mlir::failed(setFromAttr(prop.{1}, attr, emitError))) 1316 return ::mlir::failure(); 1317 )decl"; 1318 1319 // Generate the setter for any property not parsed elsewhere. 
1320 for (const NamedProperty &namedProperty : op.getProperties()) { 1321 if (fmt.usedProperties.contains(&namedProperty)) 1322 continue; 1323 1324 auto scope = body.scope("{\n", "}\n", /*indent=*/true); 1325 1326 StringRef name = namedProperty.name; 1327 const Property &prop = namedProperty.prop; 1328 bool isRequired = !prop.hasDefaultValue(); 1329 FmtContext fctx; 1330 body << formatv(propFromAttrFmt, 1331 tgfmt(prop.getConvertFromAttributeCall(), 1332 &fctx.addSubst("_attr", "propAttr") 1333 .addSubst("_storage", "propStorage") 1334 .addSubst("_diag", "emitError")), 1335 name, isRequired); 1336 } 1337 1338 // Generate the setter for any attribute not parsed elsewhere. 1339 for (const NamedAttribute &namedAttr : op.getAttributes()) { 1340 if (fmt.usedAttributes.contains(&namedAttr)) 1341 continue; 1342 1343 const Attribute &attr = namedAttr.attr; 1344 // Derived attributes do not need to be parsed. 1345 if (attr.isDerivedAttr()) 1346 continue; 1347 1348 auto scope = body.scope("{\n", "}\n", /*indent=*/true); 1349 1350 // If the attribute has a default value or is optional, it does not need to 1351 // be present in the parsed dictionary attribute. 
1352 bool isRequired = !attr.isOptional() && !attr.hasDefaultValue(); 1353 body << formatv(R"decl( 1354 auto &propStorage = prop.{0}; 1355 auto attr = dict.get("{0}"); 1356 if (attr || /*isRequired=*/{1}) {{ 1357 if (!attr) {{ 1358 emitError() << "expected key entry for {0} in DictionaryAttr to set " 1359 "Properties."; 1360 return ::mlir::failure(); 1361 } 1362 auto convertedAttr = ::llvm::dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr); 1363 if (convertedAttr) {{ 1364 propStorage = convertedAttr; 1365 } else {{ 1366 emitError() << "Invalid attribute `{0}` in property conversion: " << attr; 1367 return ::mlir::failure(); 1368 } 1369 } 1370 )decl", 1371 namedAttr.name, isRequired); 1372 } 1373 body << "return ::mlir::success();\n"; 1374 } 1375 1376 void OperationFormat::genParser(Operator &op, OpClass &opClass) { 1377 SmallVector<MethodParameter> paramList; 1378 paramList.emplace_back("::mlir::OpAsmParser &", "parser"); 1379 paramList.emplace_back("::mlir::OperationState &", "result"); 1380 1381 auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse", 1382 std::move(paramList)); 1383 auto &body = method->body(); 1384 1385 // Generate variables to store the operands and type within the format. This 1386 // allows for referencing these variables in the presence of optional 1387 // groupings. 1388 for (FormatElement *element : elements) 1389 genElementParserStorage(element, op, body); 1390 1391 // A format context used when parsing attributes with buildable types. 1392 FmtContext attrTypeCtx; 1393 attrTypeCtx.withBuilder("parser.getBuilder()"); 1394 1395 // Generate parsers for each of the elements. 1396 for (FormatElement *element : elements) 1397 genElementParser(element, body, attrTypeCtx); 1398 1399 // Generate the code to resolve the operand/result types and successors now 1400 // that they have been parsed. 
1401 genParserRegionResolution(op, body); 1402 genParserSuccessorResolution(op, body); 1403 genParserVariadicSegmentResolution(op, body); 1404 genParserTypeResolution(op, body); 1405 1406 body << " return ::mlir::success();\n"; 1407 1408 genParsedAttrPropertiesSetter(*this, op, opClass); 1409 } 1410 1411 void OperationFormat::genElementParser(FormatElement *element, MethodBody &body, 1412 FmtContext &attrTypeCtx, 1413 GenContext genCtx) { 1414 /// Optional Group. 1415 if (auto *optional = dyn_cast<OptionalElement>(element)) { 1416 auto genElementParsers = [&](FormatElement *firstElement, 1417 ArrayRef<FormatElement *> elements, 1418 bool thenGroup) { 1419 // If the anchor is a unit attribute, we don't need to print it. When 1420 // parsing, we will add this attribute if this group is present. 1421 FormatElement *elidedAnchorElement = nullptr; 1422 auto *anchorVar = dyn_cast<AttributeLikeVariable>(optional->getAnchor()); 1423 if (anchorVar && anchorVar != firstElement && anchorVar->isUnit()) { 1424 elidedAnchorElement = anchorVar; 1425 1426 if (!thenGroup == optional->isInverted()) { 1427 // Add the anchor unit attribute or property to the operation state 1428 // or set the property to true. 1429 if (isa<PropertyVariable>(anchorVar)) { 1430 body << formatv( 1431 " result.getOrAddProperties<{1}::Properties>().{0} = true;", 1432 anchorVar->getName(), opCppClassName); 1433 } else if (useProperties) { 1434 body << formatv( 1435 " result.getOrAddProperties<{1}::Properties>().{0} = " 1436 "parser.getBuilder().getUnitAttr();", 1437 anchorVar->getName(), opCppClassName); 1438 } else { 1439 body << " result.addAttribute(\"" << anchorVar->getName() 1440 << "\", parser.getBuilder().getUnitAttr());\n"; 1441 } 1442 } 1443 } 1444 1445 // Generate the rest of the elements inside an optional group. Elements in 1446 // an optional group after the guard are parsed as required. 
1447 for (FormatElement *childElement : elements) 1448 if (childElement != elidedAnchorElement) 1449 genElementParser(childElement, body, attrTypeCtx, 1450 GenContext::Optional); 1451 }; 1452 1453 ArrayRef<FormatElement *> thenElements = 1454 optional->getThenElements(/*parseable=*/true); 1455 1456 // Generate a special optional parser for the first element to gate the 1457 // parsing of the rest of the elements. 1458 FormatElement *firstElement = thenElements.front(); 1459 if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) { 1460 genAttrParser(attrVar, body, attrTypeCtx, /*parseAsOptional=*/true, 1461 useProperties, opCppClassName); 1462 body << " if (" << attrVar->getVar()->name << "Attr) {\n"; 1463 } else if (auto *propVar = dyn_cast<PropertyVariable>(firstElement)) { 1464 genPropertyParser(propVar, body, opCppClassName, /*requireParse=*/false); 1465 body << formatv("if ({0}PropParseResult.has_value() && " 1466 "succeeded(*{0}PropParseResult)) ", 1467 propVar->getVar()->name) 1468 << " {\n"; 1469 } else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) { 1470 body << " if (::mlir::succeeded(parser.parseOptional"; 1471 genLiteralParser(literal->getSpelling(), body); 1472 body << ")) {\n"; 1473 } else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) { 1474 genElementParser(opVar, body, attrTypeCtx); 1475 body << " if (!" << opVar->getVar()->name << "Operands.empty()) {\n"; 1476 } else if (auto *regionVar = dyn_cast<RegionVariable>(firstElement)) { 1477 const NamedRegion *region = regionVar->getVar(); 1478 if (region->isVariadic()) { 1479 genElementParser(regionVar, body, attrTypeCtx); 1480 body << " if (!" << region->name << "Regions.empty()) {\n"; 1481 } else { 1482 body << formatv(optionalRegionParserCode, region->name); 1483 body << " if (!" 
<< region->name << "Region->empty()) {\n "; 1484 if (hasImplicitTermTrait) 1485 body << formatv(regionEnsureTerminatorParserCode, region->name); 1486 else if (hasSingleBlockTrait) 1487 body << formatv(regionEnsureSingleBlockParserCode, region->name); 1488 } 1489 } else if (auto *custom = dyn_cast<CustomDirective>(firstElement)) { 1490 body << " if (auto optResult = [&]() -> ::mlir::OptionalParseResult {\n"; 1491 genCustomDirectiveParser(custom, body, useProperties, opCppClassName, 1492 /*isOptional=*/true); 1493 body << " return ::mlir::success();\n" 1494 << " }(); optResult.has_value() && ::mlir::failed(*optResult)) {\n" 1495 << " return ::mlir::failure();\n" 1496 << " } else if (optResult.has_value()) {\n"; 1497 } 1498 1499 genElementParsers(firstElement, thenElements.drop_front(), 1500 /*thenGroup=*/true); 1501 body << " }"; 1502 1503 // Generate the else elements. 1504 auto elseElements = optional->getElseElements(); 1505 if (!elseElements.empty()) { 1506 body << " else {\n"; 1507 ArrayRef<FormatElement *> elseElements = 1508 optional->getElseElements(/*parseable=*/true); 1509 genElementParsers(elseElements.front(), elseElements, 1510 /*thenGroup=*/false); 1511 body << " }"; 1512 } 1513 body << "\n"; 1514 1515 /// OIList Directive 1516 } else if (OIListElement *oilist = dyn_cast<OIListElement>(element)) { 1517 for (LiteralElement *le : oilist->getLiteralElements()) 1518 body << " bool " << le->getSpelling() << "Clause = false;\n"; 1519 1520 // Generate the parsing loop 1521 body << " while(true) {\n"; 1522 for (auto clause : oilist->getClauses()) { 1523 LiteralElement *lelement = std::get<0>(clause); 1524 ArrayRef<FormatElement *> pelement = std::get<1>(clause); 1525 body << "if (succeeded(parser.parseOptional"; 1526 genLiteralParser(lelement->getSpelling(), body); 1527 body << ")) {\n"; 1528 StringRef lelementName = lelement->getSpelling(); 1529 body << formatv(oilistParserCode, lelementName); 1530 if (AttributeLikeVariable *unitVarElem = 1531 
oilist->getUnitVariableParsingElement(pelement)) { 1532 if (isa<PropertyVariable>(unitVarElem)) { 1533 body << formatv( 1534 " result.getOrAddProperties<{1}::Properties>().{0} = true;", 1535 unitVarElem->getName(), opCppClassName); 1536 } else if (useProperties) { 1537 body << formatv( 1538 " result.getOrAddProperties<{1}::Properties>().{0} = " 1539 "parser.getBuilder().getUnitAttr();", 1540 unitVarElem->getName(), opCppClassName); 1541 } else { 1542 body << " result.addAttribute(\"" << unitVarElem->getName() 1543 << "\", UnitAttr::get(parser.getContext()));\n"; 1544 } 1545 } else { 1546 for (FormatElement *el : pelement) 1547 genElementParser(el, body, attrTypeCtx); 1548 } 1549 body << " } else "; 1550 } 1551 body << " {\n"; 1552 body << " break;\n"; 1553 body << " }\n"; 1554 body << "}\n"; 1555 1556 /// Literals. 1557 } else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) { 1558 body << " if (parser.parse"; 1559 genLiteralParser(literal->getSpelling(), body); 1560 body << ")\n return ::mlir::failure();\n"; 1561 1562 /// Whitespaces. 1563 } else if (isa<WhitespaceElement>(element)) { 1564 // Nothing to parse. 1565 1566 /// Arguments. 
1567 } else if (auto *attr = dyn_cast<AttributeVariable>(element)) { 1568 bool parseAsOptional = 1569 (genCtx == GenContext::Normal && attr->getVar()->attr.isOptional()); 1570 genAttrParser(attr, body, attrTypeCtx, parseAsOptional, useProperties, 1571 opCppClassName); 1572 } else if (auto *prop = dyn_cast<PropertyVariable>(element)) { 1573 genPropertyParser(prop, body, opCppClassName); 1574 1575 } else if (auto *operand = dyn_cast<OperandVariable>(element)) { 1576 ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); 1577 StringRef name = operand->getVar()->name; 1578 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) 1579 body << formatv(variadicOfVariadicOperandParserCode, name); 1580 else if (lengthKind == ArgumentLengthKind::Variadic) 1581 body << formatv(variadicOperandParserCode, name); 1582 else if (lengthKind == ArgumentLengthKind::Optional) 1583 body << formatv(optionalOperandParserCode, name); 1584 else 1585 body << formatv(operandParserCode, name); 1586 1587 } else if (auto *region = dyn_cast<RegionVariable>(element)) { 1588 bool isVariadic = region->getVar()->isVariadic(); 1589 body << formatv(isVariadic ? regionListParserCode : regionParserCode, 1590 region->getVar()->name); 1591 if (hasImplicitTermTrait) 1592 body << formatv(isVariadic ? regionListEnsureTerminatorParserCode 1593 : regionEnsureTerminatorParserCode, 1594 region->getVar()->name); 1595 else if (hasSingleBlockTrait) 1596 body << formatv(isVariadic ? regionListEnsureSingleBlockParserCode 1597 : regionEnsureSingleBlockParserCode, 1598 region->getVar()->name); 1599 1600 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) { 1601 bool isVariadic = successor->getVar()->isVariadic(); 1602 body << formatv(isVariadic ? successorListParserCode : successorParserCode, 1603 successor->getVar()->name); 1604 1605 /// Directives. 
1606 } else if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) { 1607 body.indent() << "{\n"; 1608 body.indent() << "auto loc = parser.getCurrentLocation();(void)loc;\n" 1609 << "if (parser.parseOptionalAttrDict" 1610 << (attrDict->isWithKeyword() ? "WithKeyword" : "") 1611 << "(result.attributes))\n" 1612 << " return ::mlir::failure();\n"; 1613 if (useProperties) { 1614 body << "if (failed(verifyInherentAttrs(result.name, result.attributes, " 1615 "[&]() {\n" 1616 << " return parser.emitError(loc) << \"'\" << " 1617 "result.name.getStringRef() << \"' op \";\n" 1618 << " })))\n" 1619 << " return ::mlir::failure();\n"; 1620 } 1621 body.unindent() << "}\n"; 1622 body.unindent(); 1623 } else if (isa<PropDictDirective>(element)) { 1624 if (useProperties) { 1625 body << " if (parseProperties(parser, result))\n" 1626 << " return ::mlir::failure();\n"; 1627 } 1628 } else if (auto *customDir = dyn_cast<CustomDirective>(element)) { 1629 genCustomDirectiveParser(customDir, body, useProperties, opCppClassName); 1630 } else if (isa<OperandsDirective>(element)) { 1631 body << " [[maybe_unused]] ::llvm::SMLoc allOperandLoc =" 1632 << " parser.getCurrentLocation();\n" 1633 << " if (parser.parseOperandList(allOperands))\n" 1634 << " return ::mlir::failure();\n"; 1635 1636 } else if (isa<RegionsDirective>(element)) { 1637 body << formatv(regionListParserCode, "full"); 1638 if (hasImplicitTermTrait) 1639 body << formatv(regionListEnsureTerminatorParserCode, "full"); 1640 else if (hasSingleBlockTrait) 1641 body << formatv(regionListEnsureSingleBlockParserCode, "full"); 1642 1643 } else if (isa<SuccessorsDirective>(element)) { 1644 body << formatv(successorListParserCode, "full"); 1645 1646 } else if (auto *dir = dyn_cast<TypeDirective>(element)) { 1647 ArgumentLengthKind lengthKind; 1648 StringRef listName = getTypeListName(dir->getArg(), lengthKind); 1649 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { 1650 body << formatv(variadicOfVariadicTypeParserCode, 
listName);
    } else if (lengthKind == ArgumentLengthKind::Variadic) {
      body << formatv(variadicTypeParserCode, listName);
    } else if (lengthKind == ArgumentLengthKind::Optional) {
      body << formatv(optionalTypeParserCode, listName);
    } else {
      const char *parserCode =
          dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode;
      TypeSwitch<FormatElement *>(dir->getArg())
          .Case<OperandVariable, ResultVariable>([&](auto operand) {
            body << formatv(false, parserCode,
                            operand->getVar()->constraint.getCppType(),
                            listName);
          })
          .Default([&](auto operand) {
            body << formatv(false, parserCode, "::mlir::Type", listName);
          });
    }
  } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
    ArgumentLengthKind ignored;
    body << formatv(functionalTypeParserCode,
                    getTypeListName(dir->getInputs(), ignored),
                    getTypeListName(dir->getResults(), ignored));
  } else {
    llvm_unreachable("unknown format element");
  }
}

/// Emit the parser code that resolves the result and operand types of `op`.
///
/// The emitted code, in order:
///   1. verifies the parsed types of every variable that is referenced through
///      a variable transformer (each variable at most once);
///   2. declares one `odsBuildableType<N>` local per entry in `buildableTypes`;
///   3. unless result types are inferred, adds the result types to `result`;
///   4. resolves the operands (see `genParserOperandTypeResolution`);
///   5. if result types are inferred, appends `inferReturnTypesParserCode`.
void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
  // If any of type resolutions use transformed variables, make sure that the
  // types of those variables are resolved.
  SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables;
  FmtContext verifierFCtx;
  for (TypeResolution &resolver :
       llvm::concat<TypeResolution>(resultTypes, operandTypes)) {
    std::optional<StringRef> transformer = resolver.getVarTransformer();
    if (!transformer)
      continue;
    // Ensure that we don't verify the same variables twice.
    const NamedTypeConstraint *variable = resolver.getVariable();
    if (!variable || !verifiedVariables.insert(variable).second)
      continue;

    // Emit a loop over `<name>Types` that reports a parser error for the first
    // type that does not satisfy the variable's constraint.
    auto constraint = variable->constraint;
    body << " for (::mlir::Type type : " << variable->name << "Types) {\n"
         << " (void)type;\n"
         << " if (!("
         << tgfmt(constraint.getConditionTemplate(),
                  &verifierFCtx.withSelf("type"))
         << ")) {\n"
         << formatv(" return parser.emitError(parser.getNameLoc()) << "
                    "\"'{0}' must be {1}, but got \" << type;\n",
                    variable->name, constraint.getSummary())
         << " }\n"
         << " }\n";
  }

  // Initialize the set of buildable types.
  if (!buildableTypes.empty()) {
    FmtContext typeBuilderCtx;
    typeBuilderCtx.withBuilder("parser.getBuilder()");
    for (auto &it : buildableTypes)
      body << " ::mlir::Type odsBuildableType" << it.second << " = "
           << tgfmt(it.first, &typeBuilderCtx) << ";\n";
  }

  // Emit the code necessary for a type resolver. The emitted expression is,
  // in priority order: a buildable type local, the (possibly transformed)
  // parsed types of another variable, the (possibly transformed) type of an
  // attribute, or the parsed types of `curVar` itself.
  auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) {
    if (std::optional<int> val = resolver.getBuilderIdx()) {
      body << "odsBuildableType" << *val;
    } else if (const NamedTypeConstraint *var = resolver.getVariable()) {
      if (std::optional<StringRef> tform = resolver.getVarTransformer()) {
        FmtContext fmtContext;
        fmtContext.addSubst("_ctxt", "parser.getContext()");
        // Variadic variables are passed as the whole list; others as the
        // first (only) element.
        if (var->isVariadic())
          fmtContext.withSelf(var->name + "Types");
        else
          fmtContext.withSelf(var->name + "Types[0]");
        body << tgfmt(*tform, &fmtContext);
      } else {
        body << var->name << "Types";
        if (!var->isVariadic())
          body << "[0]";
      }
    } else if (const NamedAttribute *attr = resolver.getAttribute()) {
      if (std::optional<StringRef> tform = resolver.getVarTransformer())
        body << tgfmt(*tform,
                      &FmtContext().withSelf(attr->name + "Attr.getType()"));
      else
        body << attr->name << "Attr.getType()";
    } else {
      body << curVar << "Types";
    }
  };

  // Resolve each of the result types.
  if (!infersResultTypes) {
    if (allResultTypes) {
      body << " result.addTypes(allResultTypes);\n";
    } else {
      for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
        body << " result.addTypes(";
        emitTypeResolver(resultTypes[i], op.getResultName(i));
        body << ");\n";
      }
    }
  }

  // Emit the operand type resolutions.
  genParserOperandTypeResolution(op, body, emitTypeResolver);

  // Handle return type inference once all operands have been resolved
  if (infersResultTypes)
    body << formatv(inferReturnTypesParserCode, op.getCppClassName());
}

/// Emit the `parser.resolveOperands(...)` calls for `op`.
///
/// Nothing is emitted for an op without operands. Otherwise one of three
/// shapes is produced:
///   * `allOperandTypes` set: a single call using `allOperandTypes`, with the
///     operands taken from `allOperands` or from the concatenation of the
///     per-operand `<name>Operands` lists;
///   * only `allOperands` set: a single call whose types are the concatenation
///     of the per-operand type resolvers;
///   * neither: one call per operand.
/// `emitTypeResolver` writes the type expression for one operand into `body`.
void OperationFormat::genParserOperandTypeResolution(
    Operator &op, MethodBody &body,
    function_ref<void(TypeResolution &, StringRef)> emitTypeResolver) {
  // Early exit if there are no operands.
  if (op.getNumOperands() == 0)
    return;

  // Handle the case where all operand types are grouped together with
  // "types(operands)".
  if (allOperandTypes) {
    // If `operands` was specified, use the full operand list directly.
    if (allOperands) {
      body << " if (parser.resolveOperands(allOperands, allOperandTypes, "
              "allOperandLoc, result.operands))\n"
              " return ::mlir::failure();\n";
      return;
    }

    // Otherwise, use llvm::concat to merge the disjoint operand lists together.
    // llvm::concat does not allow the case of a single range, so guard it here.
    body << " if (parser.resolveOperands(";
    if (op.getNumOperands() > 1) {
      body << "::llvm::concat<const ::mlir::OpAsmParser::UnresolvedOperand>(";
      llvm::interleaveComma(op.getOperands(), body, [&](auto &operand) {
        body << operand.name << "Operands";
      });
      body << ")";
    } else {
      body << op.operand_begin()->name << "Operands";
    }
    body << ", allOperandTypes, parser.getNameLoc(), result.operands))\n"
         << " return ::mlir::failure();\n";
    return;
  }

  // Handle the case where all operands are grouped together with "operands".
  if (allOperands) {
    body << " if (parser.resolveOperands(allOperands, ";

    // Group all of the operand types together to perform the resolution all at
    // once. Use llvm::concat to perform the merge. llvm::concat does not allow
    // the case of a single range, so guard it here.
    if (op.getNumOperands() > 1) {
      body << "::llvm::concat<const ::mlir::Type>(";
      llvm::interleaveComma(
          llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
            body << "::llvm::ArrayRef<::mlir::Type>(";
            emitTypeResolver(operandTypes[i], op.getOperand(i).name);
            body << ")";
          });
      body << ")";
    } else {
      emitTypeResolver(operandTypes.front(), op.getOperand(0).name);
    }

    body << ", allOperandLoc, result.operands))\n return "
            "::mlir::failure();\n";
    return;
  }

  // The final case is the one where each of the operands types are resolved
  // separately.
  for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
    NamedTypeConstraint &operand = op.getOperand(i);
    body << " if (parser.resolveOperands(" << operand.name << "Operands, ";

    // Resolve the type of this operand.
    TypeResolution &operandType = operandTypes[i];
    emitTypeResolver(operandType, operand.name);

    body << ", " << operand.name
         << "OperandsLoc, result.operands))\n return ::mlir::failure();\n";
  }
}

/// Emit the code that attaches the parsed regions to `result`: a single
/// `addRegions(fullRegions)` when a `regions` directive is among the top-level
/// format elements, otherwise one `addRegions`/`addRegion` call per region of
/// `op`.
void OperationFormat::genParserRegionResolution(Operator &op,
                                                MethodBody &body) {
  // Check for the case where all regions were parsed.
  bool hasAllRegions = llvm::any_of(
      elements, [](FormatElement *elt) { return isa<RegionsDirective>(elt); });
  if (hasAllRegions) {
    body << " result.addRegions(fullRegions);\n";
    return;
  }

  // Otherwise, handle each region individually.
  for (const NamedRegion &region : op.getRegions()) {
    if (region.isVariadic())
      body << " result.addRegions(" << region.name << "Regions);\n";
    else
      body << " result.addRegion(std::move(" << region.name << "Region));\n";
  }
}

/// Emit the code that attaches the parsed successors to `result`: a single
/// `addSuccessors(fullSuccessors)` when a `successors` directive is among the
/// top-level format elements, otherwise one `addSuccessors` call per successor
/// of `op`.
void OperationFormat::genParserSuccessorResolution(Operator &op,
                                                   MethodBody &body) {
  // Check for the case where all successors were parsed.
  bool hasAllSuccessors = llvm::any_of(elements, [](FormatElement *elt) {
    return isa<SuccessorsDirective>(elt);
  });
  if (hasAllSuccessors) {
    body << " result.addSuccessors(fullSuccessors);\n";
    return;
  }

  // Otherwise, handle each successor individually.
  for (const NamedSuccessor &successor : op.getSuccessors()) {
    if (successor.isVariadic())
      body << " result.addSuccessors(" << successor.name << "Successors);\n";
    else
      body << " result.addSuccessors(" << successor.name << "Successor);\n";
  }
}

/// Emit the code that records segment sizes after parsing.
///
/// When operands were parsed individually (`allOperands` is unset), this emits
/// `operandSegmentSizes` for ops with the `AttrSizedOperandSegments` trait and
/// the group-size attribute of every variadic-of-variadic operand. When result
/// types were parsed individually (`allResultTypes` is unset), it emits
/// `resultSegmentSizes` for ops with the `AttrSizedResultSegments` trait. Each
/// value is stored either in the op's properties or as an attribute, depending
/// on `usePropertiesForAttributes()` of the op's dialect.
void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
                                                         MethodBody &body) {
  if (!allOperands) {
    if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
      auto interleaveFn = [&](const NamedTypeConstraint &operand) {
        // If the operand is variadic emit the parsed size.
        if (operand.isVariableLength())
          body << "static_cast<int32_t>(" << operand.name << "Operands.size())";
        else
          body << "1";
      };
      if (op.getDialect().usePropertiesForAttributes()) {
        body << "::llvm::copy(::llvm::ArrayRef<int32_t>({";
        llvm::interleaveComma(op.getOperands(), body, interleaveFn);
        body << formatv("}), "
                        "result.getOrAddProperties<{0}::Properties>()."
                        "operandSegmentSizes.begin());\n",
                        op.getCppClassName());
      } else {
        body << " result.addAttribute(\"operandSegmentSizes\", "
             << "parser.getBuilder().getDenseI32ArrayAttr({";
        llvm::interleaveComma(op.getOperands(), body, interleaveFn);
        body << "}));\n";
      }
    }
    // Record the group sizes of each variadic-of-variadic operand from the
    // parsed `<name>OperandGroupSizes` list.
    for (const NamedTypeConstraint &operand : op.getOperands()) {
      if (!operand.isVariadicOfVariadic())
        continue;
      if (op.getDialect().usePropertiesForAttributes()) {
        body << formatv(
            " result.getOrAddProperties<{0}::Properties>().{1} = "
            "parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n",
            op.getCppClassName(),
            operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
            operand.name);
      } else {
        body << formatv(
            " result.addAttribute(\"{0}\", "
            "parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));"
            "\n",
            operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
            operand.name);
      }
    }
  }

  if (!allResultTypes &&
      op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
    auto interleaveFn = [&](const NamedTypeConstraint &result) {
      // If the result is variadic emit the parsed size.
      if (result.isVariableLength())
        body << "static_cast<int32_t>(" << result.name << "Types.size())";
      else
        body << "1";
    };
    if (op.getDialect().usePropertiesForAttributes()) {
      body << "::llvm::copy(::llvm::ArrayRef<int32_t>({";
      llvm::interleaveComma(op.getResults(), body, interleaveFn);
      body << formatv("}), "
                      "result.getOrAddProperties<{0}::Properties>()."
                      "resultSegmentSizes.begin());\n",
                      op.getCppClassName());
    } else {
      body << " result.addAttribute(\"resultSegmentSizes\", "
           << "parser.getBuilder().getDenseI32ArrayAttr({";
      llvm::interleaveComma(op.getResults(), body, interleaveFn);
      body << "}));\n";
    }
  }
}

//===----------------------------------------------------------------------===//
// PrinterGen

/// The code snippet used to generate a printer call for a region of an
/// operation that has the SingleBlockImplicitTerminator trait. The emitted
/// code prints the block terminator only when it carries attributes, operands,
/// or results.
///
/// {0}: The name of the region.
const char *regionSingleBlockImplicitTerminatorPrinterCode = R"(
  {
    bool printTerminator = true;
    if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
      printTerminator = !term->getAttrDictionary().empty() ||
                        term->getNumOperands() != 0 ||
                        term->getNumResults() != 0;
    }
    _odsPrinter.printRegion({0}, /*printEntryBlockArgs=*/true,
      /*printBlockTerminators=*/printTerminator);
  }
)";

/// The code snippet used to generate a printer call for an enum that has cases
/// that can't be represented with a keyword. It opens a scope that the caller
/// is responsible for closing.
///
/// {0}: The name of the enum attribute.
/// {1}: The name of the enum attributes symbolToString function.
const char *enumAttrBeginPrinterCode = R"(
  {
    auto caseValue = {0}();
    auto caseValueStr = {1}(caseValue);
)";

/// Generate a check that an optional or default-valued attribute or property
/// has a non-default value. For these purposes, the default value of an
/// optional attribute is its presence, even if the attribute itself has a
/// default value.
1986 static void genNonDefaultValueCheck(MethodBody &body, const Operator &op, 1987 AttributeVariable &attrElement) { 1988 Attribute attr = attrElement.getVar()->attr; 1989 std::string getter = op.getGetterName(attrElement.getVar()->name); 1990 bool optionalAndDefault = attr.isOptional() && attr.hasDefaultValue(); 1991 if (optionalAndDefault) 1992 body << "("; 1993 if (attr.isOptional()) 1994 body << getter << "Attr()"; 1995 if (optionalAndDefault) 1996 body << " && "; 1997 if (attr.hasDefaultValue()) { 1998 FmtContext fctx; 1999 fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())"); 2000 body << getter << "Attr() != " 2001 << tgfmt(attr.getConstBuilderTemplate(), &fctx, 2002 attr.getDefaultValue()); 2003 } 2004 if (optionalAndDefault) 2005 body << ")"; 2006 } 2007 2008 static void genNonDefaultValueCheck(MethodBody &body, const Operator &op, 2009 PropertyVariable &propElement) { 2010 body << op.getGetterName(propElement.getVar()->name) 2011 << "() != " << propElement.getVar()->prop.getDefaultValue(); 2012 } 2013 2014 /// Elide the variadic segment size attributes if necessary. 2015 /// This pushes elided attribute names in `elidedStorage`. 2016 static void genVariadicSegmentElision(OperationFormat &fmt, Operator &op, 2017 MethodBody &body, 2018 const char *elidedStorage) { 2019 if (!fmt.allOperands && 2020 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) 2021 body << " " << elidedStorage << ".push_back(\"operandSegmentSizes\");\n"; 2022 if (!fmt.allResultTypes && 2023 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) 2024 body << " " << elidedStorage << ".push_back(\"resultSegmentSizes\");\n"; 2025 } 2026 2027 /// Generate the printer for the 'prop-dict' directive. 
2028 static void genPropDictPrinter(OperationFormat &fmt, Operator &op, 2029 MethodBody &body) { 2030 body << " ::llvm::SmallVector<::llvm::StringRef, 2> elidedProps;\n"; 2031 2032 genVariadicSegmentElision(fmt, op, body, "elidedProps"); 2033 2034 for (const NamedProperty *namedProperty : fmt.usedProperties) 2035 body << " elidedProps.push_back(\"" << namedProperty->name << "\");\n"; 2036 for (const NamedAttribute *namedAttr : fmt.usedAttributes) 2037 body << " elidedProps.push_back(\"" << namedAttr->name << "\");\n"; 2038 2039 // Add code to check attributes for equality with their default values. 2040 // Default-valued attributes will not be printed when their value matches the 2041 // default. 2042 for (const NamedAttribute &namedAttr : op.getAttributes()) { 2043 const Attribute &attr = namedAttr.attr; 2044 if (!attr.isDerivedAttr() && attr.hasDefaultValue()) { 2045 const StringRef &name = namedAttr.name; 2046 FmtContext fctx; 2047 fctx.withBuilder("odsBuilder"); 2048 std::string defaultValue = std::string( 2049 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); 2050 body << " {\n"; 2051 body << " ::mlir::Builder odsBuilder(getContext());\n"; 2052 body << " ::mlir::Attribute attr = " << op.getGetterName(name) 2053 << "Attr();\n"; 2054 body << " if(attr && (attr == " << defaultValue << "))\n"; 2055 body << " elidedProps.push_back(\"" << name << "\");\n"; 2056 body << " }\n"; 2057 } 2058 } 2059 // Similarly, elide default-valued properties. 
2060 for (const NamedProperty &prop : op.getProperties()) { 2061 if (prop.prop.hasDefaultValue()) { 2062 body << " if (" << op.getGetterName(prop.name) 2063 << "() == " << prop.prop.getDefaultValue() << ") {"; 2064 body << " elidedProps.push_back(\"" << prop.name << "\");\n"; 2065 body << " }\n"; 2066 } 2067 } 2068 2069 if (fmt.useProperties) { 2070 body << " _odsPrinter << \" \";\n" 2071 << " printProperties(this->getContext(), _odsPrinter, " 2072 "getProperties(), elidedProps);\n"; 2073 } 2074 } 2075 2076 /// Generate the printer for the 'attr-dict' directive. 2077 static void genAttrDictPrinter(OperationFormat &fmt, Operator &op, 2078 MethodBody &body, bool withKeyword) { 2079 body << " ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;\n"; 2080 2081 genVariadicSegmentElision(fmt, op, body, "elidedAttrs"); 2082 2083 for (const StringRef key : fmt.inferredAttributes.keys()) 2084 body << " elidedAttrs.push_back(\"" << key << "\");\n"; 2085 for (const NamedAttribute *attr : fmt.usedAttributes) 2086 body << " elidedAttrs.push_back(\"" << attr->name << "\");\n"; 2087 2088 // Add code to check attributes for equality with their default values. 2089 // Default-valued attributes will not be printed when their value matches the 2090 // default. 
2091 for (const NamedAttribute &namedAttr : op.getAttributes()) { 2092 const Attribute &attr = namedAttr.attr; 2093 if (!attr.isDerivedAttr() && attr.hasDefaultValue()) { 2094 const StringRef &name = namedAttr.name; 2095 FmtContext fctx; 2096 fctx.withBuilder("odsBuilder"); 2097 std::string defaultValue = std::string( 2098 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); 2099 body << " {\n"; 2100 body << " ::mlir::Builder odsBuilder(getContext());\n"; 2101 body << " ::mlir::Attribute attr = " << op.getGetterName(name) 2102 << "Attr();\n"; 2103 body << " if(attr && (attr == " << defaultValue << "))\n"; 2104 body << " elidedAttrs.push_back(\"" << name << "\");\n"; 2105 body << " }\n"; 2106 } 2107 } 2108 if (fmt.hasPropDict) 2109 body << " _odsPrinter.printOptionalAttrDict" 2110 << (withKeyword ? "WithKeyword" : "") 2111 << "(llvm::to_vector((*this)->getDiscardableAttrs()), elidedAttrs);\n"; 2112 else 2113 body << " _odsPrinter.printOptionalAttrDict" 2114 << (withKeyword ? "WithKeyword" : "") 2115 << "((*this)->getAttrs(), elidedAttrs);\n"; 2116 } 2117 2118 /// Generate the printer for a literal value. `shouldEmitSpace` is true if a 2119 /// space should be emitted before this element. `lastWasPunctuation` is true if 2120 /// the previous element was a punctuation literal. 2121 static void genLiteralPrinter(StringRef value, MethodBody &body, 2122 bool &shouldEmitSpace, bool &lastWasPunctuation) { 2123 body << " _odsPrinter"; 2124 2125 // Don't insert a space for certain punctuation. 2126 if (shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation)) 2127 body << " << ' '"; 2128 body << " << \"" << value << "\";\n"; 2129 2130 // Insert a space after certain literals. 2131 shouldEmitSpace = 2132 value.size() != 1 || !StringRef("<({[").contains(value.front()); 2133 lastWasPunctuation = value.front() != '_' && !isalpha(value.front()); 2134 } 2135 2136 /// Generate the printer for a space. 
`shouldEmitSpace` and `lastWasPunctuation` 2137 /// are set to false. 2138 static void genSpacePrinter(bool value, MethodBody &body, bool &shouldEmitSpace, 2139 bool &lastWasPunctuation) { 2140 if (value) { 2141 body << " _odsPrinter << ' ';\n"; 2142 lastWasPunctuation = false; 2143 } else { 2144 lastWasPunctuation = true; 2145 } 2146 shouldEmitSpace = false; 2147 } 2148 2149 /// Generate the printer for a custom directive parameter. 2150 static void genCustomDirectiveParameterPrinter(FormatElement *element, 2151 const Operator &op, 2152 MethodBody &body) { 2153 if (auto *attr = dyn_cast<AttributeVariable>(element)) { 2154 body << op.getGetterName(attr->getVar()->name) << "Attr()"; 2155 2156 } else if (isa<AttrDictDirective>(element)) { 2157 body << "getOperation()->getAttrDictionary()"; 2158 2159 } else if (isa<PropDictDirective>(element)) { 2160 body << "getProperties()"; 2161 2162 } else if (auto *operand = dyn_cast<OperandVariable>(element)) { 2163 body << op.getGetterName(operand->getVar()->name) << "()"; 2164 2165 } else if (auto *region = dyn_cast<RegionVariable>(element)) { 2166 body << op.getGetterName(region->getVar()->name) << "()"; 2167 2168 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) { 2169 body << op.getGetterName(successor->getVar()->name) << "()"; 2170 2171 } else if (auto *dir = dyn_cast<RefDirective>(element)) { 2172 genCustomDirectiveParameterPrinter(dir->getArg(), op, body); 2173 2174 } else if (auto *dir = dyn_cast<TypeDirective>(element)) { 2175 auto *typeOperand = dir->getArg(); 2176 auto *operand = dyn_cast<OperandVariable>(typeOperand); 2177 auto *var = operand ? operand->getVar() 2178 : cast<ResultVariable>(typeOperand)->getVar(); 2179 std::string name = op.getGetterName(var->name); 2180 if (var->isVariadic()) 2181 body << name << "().getTypes()"; 2182 else if (var->isOptional()) 2183 body << formatv("({0}() ? 
{0}().getType() : ::mlir::Type())", name); 2184 else 2185 body << name << "().getType()"; 2186 2187 } else if (auto *string = dyn_cast<StringElement>(element)) { 2188 FmtContext ctx; 2189 ctx.withBuilder("::mlir::Builder(getContext())"); 2190 ctx.addSubst("_ctxt", "getContext()"); 2191 body << tgfmt(string->getValue(), &ctx); 2192 2193 } else if (auto *property = dyn_cast<PropertyVariable>(element)) { 2194 FmtContext ctx; 2195 const NamedProperty *namedProperty = property->getVar(); 2196 ctx.addSubst("_storage", "getProperties()." + namedProperty->name); 2197 body << tgfmt(namedProperty->prop.getConvertFromStorageCall(), &ctx); 2198 } else { 2199 llvm_unreachable("unknown custom directive parameter"); 2200 } 2201 } 2202 2203 /// Generate the printer for a custom directive. 2204 static void genCustomDirectivePrinter(CustomDirective *customDir, 2205 const Operator &op, MethodBody &body) { 2206 body << " print" << customDir->getName() << "(_odsPrinter, *this"; 2207 for (FormatElement *param : customDir->getArguments()) { 2208 body << ", "; 2209 genCustomDirectiveParameterPrinter(param, op, body); 2210 } 2211 body << ");\n"; 2212 } 2213 2214 /// Generate the printer for a region with the given variable name. 2215 static void genRegionPrinter(const Twine ®ionName, MethodBody &body, 2216 bool hasImplicitTermTrait) { 2217 if (hasImplicitTermTrait) 2218 body << formatv(regionSingleBlockImplicitTerminatorPrinterCode, regionName); 2219 else 2220 body << " _odsPrinter.printRegion(" << regionName << ");\n"; 2221 } 2222 static void genVariadicRegionPrinter(const Twine ®ionListName, 2223 MethodBody &body, 2224 bool hasImplicitTermTrait) { 2225 body << " llvm::interleaveComma(" << regionListName 2226 << ", _odsPrinter, [&](::mlir::Region ®ion) {\n "; 2227 genRegionPrinter("region", body, hasImplicitTermTrait); 2228 body << " });\n"; 2229 } 2230 2231 /// Generate the C++ for an operand to a (*-)type directive. 
2232 static MethodBody &genTypeOperandPrinter(FormatElement *arg, const Operator &op, 2233 MethodBody &body, 2234 bool useArrayRef = true) { 2235 if (isa<OperandsDirective>(arg)) 2236 return body << "getOperation()->getOperandTypes()"; 2237 if (isa<ResultsDirective>(arg)) 2238 return body << "getOperation()->getResultTypes()"; 2239 auto *operand = dyn_cast<OperandVariable>(arg); 2240 auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar(); 2241 if (var->isVariadicOfVariadic()) 2242 return body << formatv("{0}().join().getTypes()", 2243 op.getGetterName(var->name)); 2244 if (var->isVariadic()) 2245 return body << op.getGetterName(var->name) << "().getTypes()"; 2246 if (var->isOptional()) 2247 return body << formatv( 2248 "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : " 2249 "::llvm::ArrayRef<::mlir::Type>())", 2250 op.getGetterName(var->name)); 2251 if (useArrayRef) 2252 return body << "::llvm::ArrayRef<::mlir::Type>(" 2253 << op.getGetterName(var->name) << "().getType())"; 2254 return body << op.getGetterName(var->name) << "().getType()"; 2255 } 2256 2257 /// Generate the printer for an enum attribute. 2258 static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op, 2259 MethodBody &body) { 2260 Attribute baseAttr = var->attr.getBaseAttr(); 2261 const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr); 2262 std::vector<EnumAttrCase> cases = enumAttr.getAllCases(); 2263 2264 body << formatv(enumAttrBeginPrinterCode, 2265 (var->attr.isOptional() ? "*" : "") + 2266 op.getGetterName(var->name), 2267 enumAttr.getSymbolToStringFnName()); 2268 2269 // Get a string containing all of the cases that can't be represented with a 2270 // keyword. 
2271 BitVector nonKeywordCases(cases.size()); 2272 for (auto it : llvm::enumerate(cases)) { 2273 if (!canFormatStringAsKeyword(it.value().getStr())) 2274 nonKeywordCases.set(it.index()); 2275 } 2276 2277 // Otherwise if this is a bit enum attribute, don't allow cases that may 2278 // overlap with other cases. For simplicity sake, only allow cases with a 2279 // single bit value. 2280 if (enumAttr.isBitEnum()) { 2281 for (auto it : llvm::enumerate(cases)) { 2282 int64_t value = it.value().getValue(); 2283 if (value < 0 || !llvm::isPowerOf2_64(value)) 2284 nonKeywordCases.set(it.index()); 2285 } 2286 } 2287 2288 // If there are any cases that can't be used with a keyword, switch on the 2289 // case value to determine when to print in the string form. 2290 if (nonKeywordCases.any()) { 2291 body << " switch (caseValue) {\n"; 2292 StringRef cppNamespace = enumAttr.getCppNamespace(); 2293 StringRef enumName = enumAttr.getEnumClassName(); 2294 for (auto it : llvm::enumerate(cases)) { 2295 if (nonKeywordCases.test(it.index())) 2296 continue; 2297 StringRef symbol = it.value().getSymbol(); 2298 body << formatv(" case {0}::{1}::{2}:\n", cppNamespace, enumName, 2299 llvm::isDigit(symbol.front()) ? ("_" + symbol) : symbol); 2300 } 2301 body << " _odsPrinter << caseValueStr;\n" 2302 " break;\n" 2303 " default:\n" 2304 " _odsPrinter << '\"' << caseValueStr << '\"';\n" 2305 " break;\n" 2306 " }\n" 2307 " }\n"; 2308 return; 2309 } 2310 2311 body << " _odsPrinter << caseValueStr;\n" 2312 " }\n"; 2313 } 2314 2315 /// Generate the check for the anchor of an optional group. 
///
/// Writes into `body` the boolean C++ expression that is true when `anchor` is
/// present on the operation being printed. The expression depends on the kind
/// of the anchor element; `type`/functional-type directives and custom
/// directives recurse into their arguments.
static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
                                          const Operator &op,
                                          MethodBody &body) {
  TypeSwitch<FormatElement *>(anchor)
      .Case<OperandVariable, ResultVariable>([&](auto *element) {
        // Optional values are tested for presence, variadic ones for being
        // non-empty.
        // NOTE(review): nothing is emitted for a value that is neither;
        // presumably such anchors are rejected before this point - confirm.
        const NamedTypeConstraint *var = element->getVar();
        std::string name = op.getGetterName(var->name);
        if (var->isOptional())
          body << name << "()";
        else if (var->isVariadic())
          body << "!" << name << "().empty()";
      })
      .Case([&](RegionVariable *element) {
        const NamedRegion *var = element->getVar();
        std::string name = op.getGetterName(var->name);
        // TODO: Add a check for optional regions here when ODS supports it.
        body << "!" << name << "().empty()";
      })
      .Case([&](TypeDirective *element) {
        // The type of a value is present exactly when the value is.
        genOptionalGroupPrinterAnchor(element->getArg(), op, body);
      })
      .Case([&](FunctionalTypeDirective *element) {
        // Only the inputs of a functional type are considered.
        genOptionalGroupPrinterAnchor(element->getInputs(), op, body);
      })
      .Case([&](AttributeVariable *element) {
        // Consider a default-valued attribute as present if it's not the
        // default value and an optional one present if it is set.
        genNonDefaultValueCheck(body, op, *element);
      })
      .Case([&](PropertyVariable *element) {
        genNonDefaultValueCheck(body, op, *element);
      })
      .Case([&](CustomDirective *ele) {
        // A custom directive is present if any of its arguments is; each
        // argument check is parenthesized and the checks are joined by `||`.
        body << '(';
        llvm::interleave(
            ele->getArguments(), body,
            [&](FormatElement *child) {
              body << '(';
              genOptionalGroupPrinterAnchor(child, op, body);
              body << ')';
            },
            " || ");
        body << ')';
      });
}

/// Recursively gather the variable elements reachable from `element` and
/// append them to `variables`. Variables are appended directly; custom
/// directives, optional groups (then and else elements), functional-type
/// directives (inputs and results) and oilists (all parsing elements) are
/// traversed. Any other element kind is ignored.
///
/// NOTE(review): this helper is not declared `static` here; unless it is inside
/// an anonymous namespace that is not visible in this chunk, it has external
/// linkage - confirm whether that is intended.
void collect(FormatElement *element,
             SmallVectorImpl<VariableElement *> &variables) {
  TypeSwitch<FormatElement *>(element)
      .Case([&](VariableElement *var) { variables.emplace_back(var); })
      .Case([&](CustomDirective *ele) {
        for (FormatElement *arg : ele->getArguments())
          collect(arg, variables);
      })
      .Case([&](OptionalElement *ele) {
        for (FormatElement *arg : ele->getThenElements())
          collect(arg, variables);
        for (FormatElement *arg : ele->getElseElements())
          collect(arg, variables);
      })
      .Case([&](FunctionalTypeDirective *funcType) {
        collect(funcType->getInputs(), variables);
        collect(funcType->getResults(), variables);
      })
      .Case([&](OIListElement *oilist) {
        // Note: the inner loop variable deliberately-or-not shadows the outer
        // one; the range expression `arg` refers to the outer ArrayRef.
        for (ArrayRef<FormatElement *> arg : oilist->getParsingElements())
          for (FormatElement *arg : arg)
            collect(arg, variables);
      });
}

void OperationFormat::genElementPrinter(FormatElement *element,
                                        MethodBody &body, Operator &op,
                                        bool &shouldEmitSpace,
                                        bool &lastWasPunctuation) {
  // Literals are handled entirely by the literal printer.
  if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
    return genLiteralPrinter(literal->getSpelling(), body, shouldEmitSpace,
                             lastWasPunctuation);

  // Emit a whitespace element.
2396 if (auto *space = dyn_cast<WhitespaceElement>(element)) { 2397 if (space->getValue() == "\\n") { 2398 body << " _odsPrinter.printNewline();\n"; 2399 } else { 2400 genSpacePrinter(!space->getValue().empty(), body, shouldEmitSpace, 2401 lastWasPunctuation); 2402 } 2403 return; 2404 } 2405 2406 // Emit an optional group. 2407 if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) { 2408 // Emit the check for the presence of the anchor element. 2409 FormatElement *anchor = optional->getAnchor(); 2410 body << " if ("; 2411 if (optional->isInverted()) 2412 body << "!"; 2413 genOptionalGroupPrinterAnchor(anchor, op, body); 2414 body << ") {\n"; 2415 body.indent(); 2416 2417 // If the anchor is a unit attribute, we don't need to print it. When 2418 // parsing, we will add this attribute if this group is present. 2419 ArrayRef<FormatElement *> thenElements = optional->getThenElements(); 2420 ArrayRef<FormatElement *> elseElements = optional->getElseElements(); 2421 FormatElement *elidedAnchorElement = nullptr; 2422 auto *anchorAttr = dyn_cast<AttributeLikeVariable>(anchor); 2423 if (anchorAttr && anchorAttr != thenElements.front() && 2424 (elseElements.empty() || anchorAttr != elseElements.front()) && 2425 anchorAttr->isUnit()) { 2426 elidedAnchorElement = anchorAttr; 2427 } 2428 auto genElementPrinters = [&](ArrayRef<FormatElement *> elements) { 2429 for (FormatElement *childElement : elements) { 2430 if (childElement != elidedAnchorElement) { 2431 genElementPrinter(childElement, body, op, shouldEmitSpace, 2432 lastWasPunctuation); 2433 } 2434 } 2435 }; 2436 2437 // Emit each of the elements. 2438 genElementPrinters(thenElements); 2439 body << "}"; 2440 2441 // Emit each of the else elements. 
2442 if (!elseElements.empty()) { 2443 body << " else {\n"; 2444 genElementPrinters(elseElements); 2445 body << "}"; 2446 } 2447 2448 body.unindent() << "\n"; 2449 return; 2450 } 2451 2452 // Emit the OIList 2453 if (auto *oilist = dyn_cast<OIListElement>(element)) { 2454 for (auto clause : oilist->getClauses()) { 2455 LiteralElement *lelement = std::get<0>(clause); 2456 ArrayRef<FormatElement *> pelement = std::get<1>(clause); 2457 2458 SmallVector<VariableElement *> vars; 2459 for (FormatElement *el : pelement) 2460 collect(el, vars); 2461 body << " if (false"; 2462 for (VariableElement *var : vars) { 2463 TypeSwitch<FormatElement *>(var) 2464 .Case([&](AttributeVariable *attrEle) { 2465 body << " || ("; 2466 genNonDefaultValueCheck(body, op, *attrEle); 2467 body << ")"; 2468 }) 2469 .Case([&](PropertyVariable *propEle) { 2470 body << " || ("; 2471 genNonDefaultValueCheck(body, op, *propEle); 2472 body << ")"; 2473 }) 2474 .Case([&](OperandVariable *ele) { 2475 if (ele->getVar()->isVariadic()) { 2476 body << " || " << op.getGetterName(ele->getVar()->name) 2477 << "().size()"; 2478 } else { 2479 body << " || " << op.getGetterName(ele->getVar()->name) << "()"; 2480 } 2481 }) 2482 .Case([&](ResultVariable *ele) { 2483 if (ele->getVar()->isVariadic()) { 2484 body << " || " << op.getGetterName(ele->getVar()->name) 2485 << "().size()"; 2486 } else { 2487 body << " || " << op.getGetterName(ele->getVar()->name) << "()"; 2488 } 2489 }) 2490 .Case([&](RegionVariable *reg) { 2491 body << " || " << op.getGetterName(reg->getVar()->name) << "()"; 2492 }); 2493 } 2494 2495 body << ") {\n"; 2496 genLiteralPrinter(lelement->getSpelling(), body, shouldEmitSpace, 2497 lastWasPunctuation); 2498 if (oilist->getUnitVariableParsingElement(pelement) == nullptr) { 2499 for (FormatElement *element : pelement) 2500 genElementPrinter(element, body, op, shouldEmitSpace, 2501 lastWasPunctuation); 2502 } 2503 body << " }\n"; 2504 } 2505 return; 2506 } 2507 2508 // Emit the attribute 
dictionary. 2509 if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) { 2510 genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword()); 2511 lastWasPunctuation = false; 2512 return; 2513 } 2514 2515 // Emit the property dictionary. 2516 if (isa<PropDictDirective>(element)) { 2517 genPropDictPrinter(*this, op, body); 2518 lastWasPunctuation = false; 2519 return; 2520 } 2521 2522 // Optionally insert a space before the next element. The AttrDict printer 2523 // already adds a space as necessary. 2524 if (shouldEmitSpace || !lastWasPunctuation) 2525 body << " _odsPrinter << ' ';\n"; 2526 lastWasPunctuation = false; 2527 shouldEmitSpace = true; 2528 2529 if (auto *attr = dyn_cast<AttributeVariable>(element)) { 2530 const NamedAttribute *var = attr->getVar(); 2531 2532 // If we are formatting as an enum, symbolize the attribute as a string. 2533 if (canFormatEnumAttr(var)) 2534 return genEnumAttrPrinter(var, op, body); 2535 2536 // If we are formatting as a symbol name, handle it as a symbol name. 2537 if (shouldFormatSymbolNameAttr(var)) { 2538 body << " _odsPrinter.printSymbolName(" << op.getGetterName(var->name) 2539 << "Attr().getValue());\n"; 2540 return; 2541 } 2542 2543 // Elide the attribute type if it is buildable. 
2544 if (attr->getTypeBuilder()) 2545 body << " _odsPrinter.printAttributeWithoutType(" 2546 << op.getGetterName(var->name) << "Attr());\n"; 2547 else if (attr->shouldBeQualified() || 2548 var->attr.getStorageType() == "::mlir::Attribute") 2549 body << " _odsPrinter.printAttribute(" << op.getGetterName(var->name) 2550 << "Attr());\n"; 2551 else 2552 body << "_odsPrinter.printStrippedAttrOrType(" 2553 << op.getGetterName(var->name) << "Attr());\n"; 2554 } else if (auto *property = dyn_cast<PropertyVariable>(element)) { 2555 const NamedProperty *var = property->getVar(); 2556 FmtContext fmtContext; 2557 fmtContext.addSubst("_printer", "_odsPrinter"); 2558 fmtContext.addSubst("_ctxt", "getContext()"); 2559 fmtContext.addSubst("_storage", "getProperties()." + var->name); 2560 body << tgfmt(var->prop.getPrinterCall(), &fmtContext) << ";\n"; 2561 } else if (auto *operand = dyn_cast<OperandVariable>(element)) { 2562 if (operand->getVar()->isVariadicOfVariadic()) { 2563 body << " ::llvm::interleaveComma(" 2564 << op.getGetterName(operand->getVar()->name) 2565 << "(), _odsPrinter, [&](const auto &operands) { _odsPrinter << " 2566 "\"(\" << operands << " 2567 "\")\"; });\n"; 2568 2569 } else if (operand->getVar()->isOptional()) { 2570 body << " if (::mlir::Value value = " 2571 << op.getGetterName(operand->getVar()->name) << "())\n" 2572 << " _odsPrinter << value;\n"; 2573 } else { 2574 body << " _odsPrinter << " << op.getGetterName(operand->getVar()->name) 2575 << "();\n"; 2576 } 2577 } else if (auto *region = dyn_cast<RegionVariable>(element)) { 2578 const NamedRegion *var = region->getVar(); 2579 std::string name = op.getGetterName(var->name); 2580 if (var->isVariadic()) { 2581 genVariadicRegionPrinter(name + "()", body, hasImplicitTermTrait); 2582 } else { 2583 genRegionPrinter(name + "()", body, hasImplicitTermTrait); 2584 } 2585 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) { 2586 const NamedSuccessor *var = successor->getVar(); 2587 std::string 
name = op.getGetterName(var->name); 2588 if (var->isVariadic()) 2589 body << " ::llvm::interleaveComma(" << name << "(), _odsPrinter);\n"; 2590 else 2591 body << " _odsPrinter << " << name << "();\n"; 2592 } else if (auto *dir = dyn_cast<CustomDirective>(element)) { 2593 genCustomDirectivePrinter(dir, op, body); 2594 } else if (isa<OperandsDirective>(element)) { 2595 body << " _odsPrinter << getOperation()->getOperands();\n"; 2596 } else if (isa<RegionsDirective>(element)) { 2597 genVariadicRegionPrinter("getOperation()->getRegions()", body, 2598 hasImplicitTermTrait); 2599 } else if (isa<SuccessorsDirective>(element)) { 2600 body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), " 2601 "_odsPrinter);\n"; 2602 } else if (auto *dir = dyn_cast<TypeDirective>(element)) { 2603 if (auto *operand = dyn_cast<OperandVariable>(dir->getArg())) { 2604 if (operand->getVar()->isVariadicOfVariadic()) { 2605 body << formatv( 2606 " ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, " 2607 "[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << " 2608 "types << \")\"; });\n", 2609 op.getGetterName(operand->getVar()->name)); 2610 return; 2611 } 2612 } 2613 const NamedTypeConstraint *var = nullptr; 2614 { 2615 if (auto *operand = dyn_cast<OperandVariable>(dir->getArg())) 2616 var = operand->getVar(); 2617 else if (auto *operand = dyn_cast<ResultVariable>(dir->getArg())) 2618 var = operand->getVar(); 2619 } 2620 if (var && !var->isVariadicOfVariadic() && !var->isVariadic() && 2621 !var->isOptional()) { 2622 StringRef cppType = var->constraint.getCppType(); 2623 if (dir->shouldBeQualified()) { 2624 body << " _odsPrinter << " << op.getGetterName(var->name) 2625 << "().getType();\n"; 2626 return; 2627 } 2628 body << " {\n" 2629 << " auto type = " << op.getGetterName(var->name) 2630 << "().getType();\n" 2631 << " if (auto validType = ::llvm::dyn_cast<" << cppType 2632 << ">(type))\n" 2633 << " _odsPrinter.printStrippedAttrOrType(validType);\n" 2634 << " else\n" 2635 << " 
_odsPrinter << type;\n" 2636 << " }\n"; 2637 return; 2638 } 2639 body << " _odsPrinter << "; 2640 genTypeOperandPrinter(dir->getArg(), op, body, /*useArrayRef=*/false) 2641 << ";\n"; 2642 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) { 2643 body << " _odsPrinter.printFunctionalType("; 2644 genTypeOperandPrinter(dir->getInputs(), op, body) << ", "; 2645 genTypeOperandPrinter(dir->getResults(), op, body) << ");\n"; 2646 } else { 2647 llvm_unreachable("unknown format element"); 2648 } 2649 } 2650 2651 void OperationFormat::genPrinter(Operator &op, OpClass &opClass) { 2652 auto *method = opClass.addMethod( 2653 "void", "print", 2654 MethodParameter("::mlir::OpAsmPrinter &", "_odsPrinter")); 2655 auto &body = method->body(); 2656 2657 // Flags for if we should emit a space, and if the last element was 2658 // punctuation. 2659 bool shouldEmitSpace = true, lastWasPunctuation = false; 2660 for (FormatElement *element : elements) 2661 genElementPrinter(element, body, op, shouldEmitSpace, lastWasPunctuation); 2662 } 2663 2664 //===----------------------------------------------------------------------===// 2665 // OpFormatParser 2666 //===----------------------------------------------------------------------===// 2667 2668 /// Function to find an element within the given range that has the same name as 2669 /// 'name'. 2670 template <typename RangeT> 2671 static auto findArg(RangeT &&range, StringRef name) { 2672 auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; }); 2673 return it != range.end() ? &*it : nullptr; 2674 } 2675 2676 namespace { 2677 /// This class implements a parser for an instance of an operation assembly 2678 /// format. 
2679 class OpFormatParser : public FormatParser { 2680 public: 2681 OpFormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op) 2682 : FormatParser(mgr, op.getLoc()[0]), fmt(format), op(op), 2683 seenOperandTypes(op.getNumOperands()), 2684 seenResultTypes(op.getNumResults()) {} 2685 2686 protected: 2687 /// Verify the format elements. 2688 LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override; 2689 /// Verify the arguments to a custom directive. 2690 LogicalResult 2691 verifyCustomDirectiveArguments(SMLoc loc, 2692 ArrayRef<FormatElement *> arguments) override; 2693 /// Verify the elements of an optional group. 2694 LogicalResult verifyOptionalGroupElements(SMLoc loc, 2695 ArrayRef<FormatElement *> elements, 2696 FormatElement *anchor) override; 2697 LogicalResult verifyOptionalGroupElement(SMLoc loc, FormatElement *element, 2698 bool isAnchor); 2699 2700 LogicalResult markQualified(SMLoc loc, FormatElement *element) override; 2701 2702 /// Parse an operation variable. 2703 FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name, 2704 Context ctx) override; 2705 /// Parse an operation format directive. 2706 FailureOr<FormatElement *> 2707 parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override; 2708 2709 private: 2710 /// This struct represents a type resolution instance. It includes a specific 2711 /// type as well as an optional transformer to apply to that type in order to 2712 /// properly resolve the type of a variable. 2713 struct TypeResolutionInstance { 2714 ConstArgument resolver; 2715 std::optional<StringRef> transformer; 2716 }; 2717 2718 /// Verify the state of operation attributes within the format. 2719 LogicalResult verifyAttributes(SMLoc loc, ArrayRef<FormatElement *> elements); 2720 2721 /// Verify that attributes elements aren't followed by colon literals. 
2722 LogicalResult verifyAttributeColonType(SMLoc loc, 2723 ArrayRef<FormatElement *> elements); 2724 /// Verify that the attribute dictionary directive isn't followed by a region. 2725 LogicalResult verifyAttrDictRegion(SMLoc loc, 2726 ArrayRef<FormatElement *> elements); 2727 2728 /// Verify the state of operation operands within the format. 2729 LogicalResult 2730 verifyOperands(SMLoc loc, 2731 StringMap<TypeResolutionInstance> &variableTyResolver); 2732 2733 /// Verify the state of operation regions within the format. 2734 LogicalResult verifyRegions(SMLoc loc); 2735 2736 /// Verify the state of operation results within the format. 2737 LogicalResult 2738 verifyResults(SMLoc loc, 2739 StringMap<TypeResolutionInstance> &variableTyResolver); 2740 2741 /// Verify the state of operation successors within the format. 2742 LogicalResult verifySuccessors(SMLoc loc); 2743 2744 LogicalResult verifyOIListElements(SMLoc loc, 2745 ArrayRef<FormatElement *> elements); 2746 2747 /// Given the values of an `AllTypesMatch` trait, check for inferable type 2748 /// resolution. 2749 void handleAllTypesMatchConstraint( 2750 ArrayRef<StringRef> values, 2751 StringMap<TypeResolutionInstance> &variableTyResolver); 2752 /// Check for inferable type resolution given all operands, and or results, 2753 /// have the same type. If 'includeResults' is true, the results also have the 2754 /// same type as all of the operands. 2755 void handleSameTypesConstraint( 2756 StringMap<TypeResolutionInstance> &variableTyResolver, 2757 bool includeResults); 2758 /// Check for inferable type resolution based on another operand, result, or 2759 /// attribute. 2760 void handleTypesMatchConstraint( 2761 StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def); 2762 2763 /// Returns an argument or attribute with the given name that has been seen 2764 /// within the format. 2765 ConstArgument findSeenArg(StringRef name); 2766 2767 /// Parse the various different directives. 
2768 FailureOr<FormatElement *> parsePropDictDirective(SMLoc loc, Context context); 2769 FailureOr<FormatElement *> parseAttrDictDirective(SMLoc loc, Context context, 2770 bool withKeyword); 2771 FailureOr<FormatElement *> parseFunctionalTypeDirective(SMLoc loc, 2772 Context context); 2773 FailureOr<FormatElement *> parseOIListDirective(SMLoc loc, Context context); 2774 LogicalResult verifyOIListParsingElement(FormatElement *element, SMLoc loc); 2775 FailureOr<FormatElement *> parseOperandsDirective(SMLoc loc, Context context); 2776 FailureOr<FormatElement *> parseRegionsDirective(SMLoc loc, Context context); 2777 FailureOr<FormatElement *> parseResultsDirective(SMLoc loc, Context context); 2778 FailureOr<FormatElement *> parseSuccessorsDirective(SMLoc loc, 2779 Context context); 2780 FailureOr<FormatElement *> parseTypeDirective(SMLoc loc, Context context); 2781 FailureOr<FormatElement *> parseTypeDirectiveOperand(SMLoc loc, 2782 bool isRefChild = false); 2783 2784 //===--------------------------------------------------------------------===// 2785 // Fields 2786 //===--------------------------------------------------------------------===// 2787 2788 OperationFormat &fmt; 2789 Operator &op; 2790 2791 // The following are various bits of format state used for verification 2792 // during parsing. 
2793 bool hasAttrDict = false; 2794 bool hasPropDict = false; 2795 bool hasAllRegions = false, hasAllSuccessors = false; 2796 bool canInferResultTypes = false; 2797 llvm::SmallBitVector seenOperandTypes, seenResultTypes; 2798 llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs; 2799 llvm::DenseSet<const NamedTypeConstraint *> seenOperands; 2800 llvm::DenseSet<const NamedRegion *> seenRegions; 2801 llvm::DenseSet<const NamedSuccessor *> seenSuccessors; 2802 llvm::SmallSetVector<const NamedProperty *, 8> seenProperties; 2803 }; 2804 } // namespace 2805 2806 LogicalResult OpFormatParser::verify(SMLoc loc, 2807 ArrayRef<FormatElement *> elements) { 2808 // Check that the attribute dictionary is in the format. 2809 if (!hasAttrDict) 2810 return emitError(loc, "'attr-dict' directive not found in " 2811 "custom assembly format"); 2812 2813 // Check for any type traits that we can use for inferring types. 2814 StringMap<TypeResolutionInstance> variableTyResolver; 2815 for (const Trait &trait : op.getTraits()) { 2816 const Record &def = trait.getDef(); 2817 if (def.isSubClassOf("AllTypesMatch")) { 2818 handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"), 2819 variableTyResolver); 2820 } else if (def.getName() == "SameTypeOperands") { 2821 handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false); 2822 } else if (def.getName() == "SameOperandsAndResultType") { 2823 handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true); 2824 } else if (def.isSubClassOf("TypesMatchWith")) { 2825 handleTypesMatchConstraint(variableTyResolver, def); 2826 } else if (!op.allResultTypesKnown()) { 2827 // This doesn't check the name directly to handle 2828 // DeclareOpInterfaceMethods<InferTypeOpInterface> 2829 // and the like. 2830 // TODO: Add hasCppInterface check. 
2831 if (auto name = def.getValueAsOptionalString("cppInterfaceName")) { 2832 if (*name == "InferTypeOpInterface" && 2833 def.getValueAsString("cppNamespace") == "::mlir") 2834 canInferResultTypes = true; 2835 } 2836 } 2837 } 2838 2839 // Verify the state of the various operation components. 2840 if (failed(verifyAttributes(loc, elements)) || 2841 failed(verifyResults(loc, variableTyResolver)) || 2842 failed(verifyOperands(loc, variableTyResolver)) || 2843 failed(verifyRegions(loc)) || failed(verifySuccessors(loc)) || 2844 failed(verifyOIListElements(loc, elements))) 2845 return failure(); 2846 2847 // Collect the set of used attributes in the format. 2848 fmt.usedAttributes = std::move(seenAttrs); 2849 fmt.usedProperties = std::move(seenProperties); 2850 2851 // Set whether prop-dict is used in the format 2852 fmt.hasPropDict = hasPropDict; 2853 return success(); 2854 } 2855 2856 LogicalResult 2857 OpFormatParser::verifyAttributes(SMLoc loc, 2858 ArrayRef<FormatElement *> elements) { 2859 // Check that there are no `:` literals after an attribute without a constant 2860 // type. The attribute grammar contains an optional trailing colon type, which 2861 // can lead to unexpected and generally unintended behavior. Given that, it is 2862 // better to just error out here instead. 2863 if (failed(verifyAttributeColonType(loc, elements))) 2864 return failure(); 2865 // Check that there are no region variables following an attribute dicitonary. 2866 // Both start with `{` and so the optional attribute dictionary can cause 2867 // format ambiguities. 2868 if (failed(verifyAttrDictRegion(loc, elements))) 2869 return failure(); 2870 2871 // Check for VariadicOfVariadic variables. The segment attribute of those 2872 // variables will be infered. 
2873 for (const NamedTypeConstraint *var : seenOperands) { 2874 if (var->constraint.isVariadicOfVariadic()) { 2875 fmt.inferredAttributes.insert( 2876 var->constraint.getVariadicOfVariadicSegmentSizeAttr()); 2877 } 2878 } 2879 2880 return success(); 2881 } 2882 2883 /// Returns whether the single format element is optionally parsed. 2884 static bool isOptionallyParsed(FormatElement *el) { 2885 if (auto *attrVar = dyn_cast<AttributeVariable>(el)) { 2886 Attribute attr = attrVar->getVar()->attr; 2887 return attr.isOptional() || attr.hasDefaultValue(); 2888 } 2889 if (auto *propVar = dyn_cast<PropertyVariable>(el)) { 2890 const Property &prop = propVar->getVar()->prop; 2891 return prop.hasDefaultValue() && prop.hasOptionalParser(); 2892 } 2893 if (auto *operandVar = dyn_cast<OperandVariable>(el)) { 2894 const NamedTypeConstraint *operand = operandVar->getVar(); 2895 return operand->isOptional() || operand->isVariadic() || 2896 operand->isVariadicOfVariadic(); 2897 } 2898 if (auto *successorVar = dyn_cast<SuccessorVariable>(el)) 2899 return successorVar->getVar()->isVariadic(); 2900 if (auto *regionVar = dyn_cast<RegionVariable>(el)) 2901 return regionVar->getVar()->isVariadic(); 2902 return isa<WhitespaceElement, AttrDictDirective>(el); 2903 } 2904 2905 /// Scan the given range of elements from the start for an invalid format 2906 /// element that satisfies `isInvalid`, skipping any optionally-parsed elements. 2907 /// If an optional group is encountered, this function recurses into the 'then' 2908 /// and 'else' elements to check if they are invalid. Returns `success` if the 2909 /// range is known to be valid or `std::nullopt` if scanning reached the end. 2910 /// 2911 /// Since the guard element of an optional group is required, this function 2912 /// accepts an optional element pointer to mark it as required. 
2913 static std::optional<LogicalResult> checkRangeForElement( 2914 FormatElement *base, 2915 function_ref<bool(FormatElement *, FormatElement *)> isInvalid, 2916 iterator_range<ArrayRef<FormatElement *>::iterator> elementRange, 2917 FormatElement *optionalGuard = nullptr) { 2918 for (FormatElement *element : elementRange) { 2919 // If we encounter an invalid element, return an error. 2920 if (isInvalid(base, element)) 2921 return failure(); 2922 2923 // Recurse on optional groups. 2924 if (auto *optional = dyn_cast<OptionalElement>(element)) { 2925 if (std::optional<LogicalResult> result = checkRangeForElement( 2926 base, isInvalid, optional->getThenElements(), 2927 // The optional group guard is required for the group. 2928 optional->getThenElements().front())) 2929 if (failed(*result)) 2930 return failure(); 2931 if (std::optional<LogicalResult> result = checkRangeForElement( 2932 base, isInvalid, optional->getElseElements())) 2933 if (failed(*result)) 2934 return failure(); 2935 // Skip the optional group. 2936 continue; 2937 } 2938 2939 // Skip optionally parsed elements. 2940 if (element != optionalGuard && isOptionallyParsed(element)) 2941 continue; 2942 2943 // We found a closing element that is valid. 2944 return success(); 2945 } 2946 // Return std::nullopt to indicate that we reached the end. 2947 return std::nullopt; 2948 } 2949 2950 /// For the given elements, check whether any attributes are followed by a colon 2951 /// literal, resulting in an ambiguous assembly format. Returns a non-null 2952 /// attribute if verification of said attribute reached the end of the range. 2953 /// Returns null if all attribute elements are verified. 
2954 static FailureOr<FormatElement *> verifyAdjacentElements( 2955 function_ref<bool(FormatElement *)> isBase, 2956 function_ref<bool(FormatElement *, FormatElement *)> isInvalid, 2957 ArrayRef<FormatElement *> elements) { 2958 for (auto *it = elements.begin(), *e = elements.end(); it != e; ++it) { 2959 // The current attribute being verified. 2960 FormatElement *base; 2961 2962 if (isBase(*it)) { 2963 base = *it; 2964 } else if (auto *optional = dyn_cast<OptionalElement>(*it)) { 2965 // Recurse on optional groups. 2966 FailureOr<FormatElement *> thenResult = verifyAdjacentElements( 2967 isBase, isInvalid, optional->getThenElements()); 2968 if (failed(thenResult)) 2969 return failure(); 2970 FailureOr<FormatElement *> elseResult = verifyAdjacentElements( 2971 isBase, isInvalid, optional->getElseElements()); 2972 if (failed(elseResult)) 2973 return failure(); 2974 // If either optional group has an unverified attribute, save it. 2975 // Otherwise, move on to the next element. 2976 if (!(base = *thenResult) && !(base = *elseResult)) 2977 continue; 2978 } else { 2979 continue; 2980 } 2981 2982 // Verify subsequent elements for potential ambiguities. 2983 if (std::optional<LogicalResult> result = 2984 checkRangeForElement(base, isInvalid, {std::next(it), e})) { 2985 if (failed(*result)) 2986 return failure(); 2987 } else { 2988 // Since we reached the end, return the attribute as unverified. 2989 return base; 2990 } 2991 } 2992 // All attribute elements are known to be verified. 2993 return nullptr; 2994 } 2995 2996 LogicalResult 2997 OpFormatParser::verifyAttributeColonType(SMLoc loc, 2998 ArrayRef<FormatElement *> elements) { 2999 auto isBase = [](FormatElement *el) { 3000 auto *attr = dyn_cast<AttributeVariable>(el); 3001 if (!attr) 3002 return false; 3003 // Check only attributes without type builders or that are known to call 3004 // the generic attribute parser. 
3005 return !attr->getTypeBuilder() && 3006 (attr->shouldBeQualified() || 3007 attr->getVar()->attr.getStorageType() == "::mlir::Attribute"); 3008 }; 3009 auto isInvalid = [&](FormatElement *base, FormatElement *el) { 3010 auto *literal = dyn_cast<LiteralElement>(el); 3011 if (!literal || literal->getSpelling() != ":") 3012 return false; 3013 // If we encounter `:`, the range is known to be invalid. 3014 (void)emitError( 3015 loc, formatv("format ambiguity caused by `:` literal found after " 3016 "attribute `{0}` which does not have a buildable type", 3017 cast<AttributeVariable>(base)->getVar()->name)); 3018 return true; 3019 }; 3020 return verifyAdjacentElements(isBase, isInvalid, elements); 3021 } 3022 3023 LogicalResult 3024 OpFormatParser::verifyAttrDictRegion(SMLoc loc, 3025 ArrayRef<FormatElement *> elements) { 3026 auto isBase = [](FormatElement *el) { 3027 if (auto *attrDict = dyn_cast<AttrDictDirective>(el)) 3028 return !attrDict->isWithKeyword(); 3029 return false; 3030 }; 3031 auto isInvalid = [&](FormatElement *base, FormatElement *el) { 3032 auto *region = dyn_cast<RegionVariable>(el); 3033 if (!region) 3034 return false; 3035 (void)emitErrorAndNote( 3036 loc, 3037 formatv("format ambiguity caused by `attr-dict` directive " 3038 "followed by region `{0}`", 3039 region->getVar()->name), 3040 "try using `attr-dict-with-keyword` instead"); 3041 return true; 3042 }; 3043 return verifyAdjacentElements(isBase, isInvalid, elements); 3044 } 3045 3046 LogicalResult OpFormatParser::verifyOperands( 3047 SMLoc loc, StringMap<TypeResolutionInstance> &variableTyResolver) { 3048 // Check that all of the operands are within the format, and their types can 3049 // be inferred. 3050 auto &buildableTypes = fmt.buildableTypes; 3051 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) { 3052 NamedTypeConstraint &operand = op.getOperand(i); 3053 3054 // Check that the operand itself is in the format. 
3055 if (!fmt.allOperands && !seenOperands.count(&operand)) { 3056 return emitErrorAndNote(loc, 3057 "operand #" + Twine(i) + ", named '" + 3058 operand.name + "', not found", 3059 "suggest adding a '$" + operand.name + 3060 "' directive to the custom assembly format"); 3061 } 3062 3063 // Check that the operand type is in the format, or that it can be inferred. 3064 if (fmt.allOperandTypes || seenOperandTypes.test(i)) 3065 continue; 3066 3067 // Check to see if we can infer this type from another variable. 3068 auto varResolverIt = variableTyResolver.find(op.getOperand(i).name); 3069 if (varResolverIt != variableTyResolver.end()) { 3070 TypeResolutionInstance &resolver = varResolverIt->second; 3071 fmt.operandTypes[i].setResolver(resolver.resolver, resolver.transformer); 3072 continue; 3073 } 3074 3075 // Similarly to results, allow a custom builder for resolving the type if 3076 // we aren't using the 'operands' directive. 3077 std::optional<StringRef> builder = operand.constraint.getBuilderCall(); 3078 if (!builder || (fmt.allOperands && operand.isVariableLength())) { 3079 return emitErrorAndNote( 3080 loc, 3081 "type of operand #" + Twine(i) + ", named '" + operand.name + 3082 "', is not buildable and a buildable type cannot be inferred", 3083 "suggest adding a type constraint to the operation or adding a " 3084 "'type($" + 3085 operand.name + ")' directive to the " + "custom assembly format"); 3086 } 3087 auto it = buildableTypes.insert({*builder, buildableTypes.size()}); 3088 fmt.operandTypes[i].setBuilderIdx(it.first->second); 3089 } 3090 return success(); 3091 } 3092 3093 LogicalResult OpFormatParser::verifyRegions(SMLoc loc) { 3094 // Check that all of the regions are within the format. 
3095 if (hasAllRegions) 3096 return success(); 3097 3098 for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) { 3099 const NamedRegion ®ion = op.getRegion(i); 3100 if (!seenRegions.count(®ion)) { 3101 return emitErrorAndNote(loc, 3102 "region #" + Twine(i) + ", named '" + 3103 region.name + "', not found", 3104 "suggest adding a '$" + region.name + 3105 "' directive to the custom assembly format"); 3106 } 3107 } 3108 return success(); 3109 } 3110 3111 LogicalResult OpFormatParser::verifyResults( 3112 SMLoc loc, StringMap<TypeResolutionInstance> &variableTyResolver) { 3113 // If we format all of the types together, there is nothing to check. 3114 if (fmt.allResultTypes) 3115 return success(); 3116 3117 // If no result types are specified and we can infer them, infer all result 3118 // types 3119 if (op.getNumResults() > 0 && seenResultTypes.count() == 0 && 3120 canInferResultTypes) { 3121 fmt.infersResultTypes = true; 3122 return success(); 3123 } 3124 3125 // Check that all of the result types can be inferred. 3126 auto &buildableTypes = fmt.buildableTypes; 3127 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) { 3128 if (seenResultTypes.test(i)) 3129 continue; 3130 3131 // Check to see if we can infer this type from another variable. 3132 auto varResolverIt = variableTyResolver.find(op.getResultName(i)); 3133 if (varResolverIt != variableTyResolver.end()) { 3134 TypeResolutionInstance resolver = varResolverIt->second; 3135 fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer); 3136 continue; 3137 } 3138 3139 // If the result is not variable length, allow for the case where the type 3140 // has a builder that we can use. 
3141 NamedTypeConstraint &result = op.getResult(i); 3142 std::optional<StringRef> builder = result.constraint.getBuilderCall(); 3143 if (!builder || result.isVariableLength()) { 3144 return emitErrorAndNote( 3145 loc, 3146 "type of result #" + Twine(i) + ", named '" + result.name + 3147 "', is not buildable and a buildable type cannot be inferred", 3148 "suggest adding a type constraint to the operation or adding a " 3149 "'type($" + 3150 result.name + ")' directive to the " + "custom assembly format"); 3151 } 3152 // Note in the format that this result uses the custom builder. 3153 auto it = buildableTypes.insert({*builder, buildableTypes.size()}); 3154 fmt.resultTypes[i].setBuilderIdx(it.first->second); 3155 } 3156 return success(); 3157 } 3158 3159 LogicalResult OpFormatParser::verifySuccessors(SMLoc loc) { 3160 // Check that all of the successors are within the format. 3161 if (hasAllSuccessors) 3162 return success(); 3163 3164 for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) { 3165 const NamedSuccessor &successor = op.getSuccessor(i); 3166 if (!seenSuccessors.count(&successor)) { 3167 return emitErrorAndNote(loc, 3168 "successor #" + Twine(i) + ", named '" + 3169 successor.name + "', not found", 3170 "suggest adding a '$" + successor.name + 3171 "' directive to the custom assembly format"); 3172 } 3173 } 3174 return success(); 3175 } 3176 3177 LogicalResult 3178 OpFormatParser::verifyOIListElements(SMLoc loc, 3179 ArrayRef<FormatElement *> elements) { 3180 // Check that all of the successors are within the format. 3181 SmallVector<StringRef> prohibitedLiterals; 3182 for (FormatElement *it : elements) { 3183 if (auto *oilist = dyn_cast<OIListElement>(it)) { 3184 if (!prohibitedLiterals.empty()) { 3185 // We just saw an oilist element in last iteration. Literals should not 3186 // match. 
3187 for (LiteralElement *literal : oilist->getLiteralElements()) { 3188 if (find(prohibitedLiterals, literal->getSpelling()) != 3189 prohibitedLiterals.end()) { 3190 return emitError( 3191 loc, "format ambiguity because " + literal->getSpelling() + 3192 " is used in two adjacent oilist elements."); 3193 } 3194 } 3195 } 3196 for (LiteralElement *literal : oilist->getLiteralElements()) 3197 prohibitedLiterals.push_back(literal->getSpelling()); 3198 } else if (auto *literal = dyn_cast<LiteralElement>(it)) { 3199 if (find(prohibitedLiterals, literal->getSpelling()) != 3200 prohibitedLiterals.end()) { 3201 return emitError( 3202 loc, 3203 "format ambiguity because " + literal->getSpelling() + 3204 " is used both in oilist element and the adjacent literal."); 3205 } 3206 prohibitedLiterals.clear(); 3207 } else { 3208 prohibitedLiterals.clear(); 3209 } 3210 } 3211 return success(); 3212 } 3213 3214 void OpFormatParser::handleAllTypesMatchConstraint( 3215 ArrayRef<StringRef> values, 3216 StringMap<TypeResolutionInstance> &variableTyResolver) { 3217 for (unsigned i = 0, e = values.size(); i != e; ++i) { 3218 // Check to see if this value matches a resolved operand or result type. 3219 ConstArgument arg = findSeenArg(values[i]); 3220 if (!arg) 3221 continue; 3222 3223 // Mark this value as the type resolver for the other variables. 3224 for (unsigned j = 0; j != i; ++j) 3225 variableTyResolver[values[j]] = {arg, std::nullopt}; 3226 for (unsigned j = i + 1; j != e; ++j) 3227 variableTyResolver[values[j]] = {arg, std::nullopt}; 3228 } 3229 } 3230 3231 void OpFormatParser::handleSameTypesConstraint( 3232 StringMap<TypeResolutionInstance> &variableTyResolver, 3233 bool includeResults) { 3234 const NamedTypeConstraint *resolver = nullptr; 3235 int resolvedIt = -1; 3236 3237 // Check to see if there is an operand or result to use for the resolution. 
3238 if ((resolvedIt = seenOperandTypes.find_first()) != -1) 3239 resolver = &op.getOperand(resolvedIt); 3240 else if (includeResults && (resolvedIt = seenResultTypes.find_first()) != -1) 3241 resolver = &op.getResult(resolvedIt); 3242 else 3243 return; 3244 3245 // Set the resolvers for each operand and result. 3246 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) 3247 if (!seenOperandTypes.test(i)) 3248 variableTyResolver[op.getOperand(i).name] = {resolver, std::nullopt}; 3249 if (includeResults) { 3250 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) 3251 if (!seenResultTypes.test(i)) 3252 variableTyResolver[op.getResultName(i)] = {resolver, std::nullopt}; 3253 } 3254 } 3255 3256 void OpFormatParser::handleTypesMatchConstraint( 3257 StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def) { 3258 StringRef lhsName = def.getValueAsString("lhs"); 3259 StringRef rhsName = def.getValueAsString("rhs"); 3260 StringRef transformer = def.getValueAsString("transformer"); 3261 if (ConstArgument arg = findSeenArg(lhsName)) 3262 variableTyResolver[rhsName] = {arg, transformer}; 3263 } 3264 3265 ConstArgument OpFormatParser::findSeenArg(StringRef name) { 3266 if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name)) 3267 return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr; 3268 if (const NamedTypeConstraint *arg = findArg(op.getResults(), name)) 3269 return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr; 3270 if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) 3271 return seenAttrs.count(attr) ? attr : nullptr; 3272 return nullptr; 3273 } 3274 3275 FailureOr<FormatElement *> 3276 OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) { 3277 // Check that the parsed argument is something actually registered on the op. 
3278 // Attributes 3279 if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) { 3280 if (ctx == TypeDirectiveContext) 3281 return emitError( 3282 loc, "attributes cannot be used as children to a `type` directive"); 3283 if (ctx == RefDirectiveContext) { 3284 if (!seenAttrs.count(attr)) 3285 return emitError(loc, "attribute '" + name + 3286 "' must be bound before it is referenced"); 3287 } else if (!seenAttrs.insert(attr)) { 3288 return emitError(loc, "attribute '" + name + "' is already bound"); 3289 } 3290 3291 return create<AttributeVariable>(attr); 3292 } 3293 3294 if (const NamedProperty *property = findArg(op.getProperties(), name)) { 3295 if (ctx == TypeDirectiveContext) 3296 return emitError( 3297 loc, "properties cannot be used as children to a `type` directive"); 3298 if (ctx == RefDirectiveContext) { 3299 if (!seenProperties.count(property)) 3300 return emitError(loc, "property '" + name + 3301 "' must be bound before it is referenced"); 3302 } else { 3303 if (!seenProperties.insert(property)) 3304 return emitError(loc, "property '" + name + "' is already bound"); 3305 } 3306 3307 return create<PropertyVariable>(property); 3308 } 3309 3310 // Operands 3311 if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) { 3312 if (ctx == TopLevelContext || ctx == CustomDirectiveContext) { 3313 if (fmt.allOperands || !seenOperands.insert(operand).second) 3314 return emitError(loc, "operand '" + name + "' is already bound"); 3315 } else if (ctx == RefDirectiveContext && !seenOperands.count(operand)) { 3316 return emitError(loc, "operand '" + name + 3317 "' must be bound before it is referenced"); 3318 } 3319 return create<OperandVariable>(operand); 3320 } 3321 // Regions 3322 if (const NamedRegion *region = findArg(op.getRegions(), name)) { 3323 if (ctx == TopLevelContext || ctx == CustomDirectiveContext) { 3324 if (hasAllRegions || !seenRegions.insert(region).second) 3325 return emitError(loc, "region '" + name + "' is already 
bound"); 3326 } else if (ctx == RefDirectiveContext && !seenRegions.count(region)) { 3327 return emitError(loc, "region '" + name + 3328 "' must be bound before it is referenced"); 3329 } else { 3330 return emitError(loc, "regions can only be used at the top level"); 3331 } 3332 return create<RegionVariable>(region); 3333 } 3334 // Results. 3335 if (const auto *result = findArg(op.getResults(), name)) { 3336 if (ctx != TypeDirectiveContext) 3337 return emitError(loc, "result variables can can only be used as a child " 3338 "to a 'type' directive"); 3339 return create<ResultVariable>(result); 3340 } 3341 // Successors. 3342 if (const auto *successor = findArg(op.getSuccessors(), name)) { 3343 if (ctx == TopLevelContext || ctx == CustomDirectiveContext) { 3344 if (hasAllSuccessors || !seenSuccessors.insert(successor).second) 3345 return emitError(loc, "successor '" + name + "' is already bound"); 3346 } else if (ctx == RefDirectiveContext && !seenSuccessors.count(successor)) { 3347 return emitError(loc, "successor '" + name + 3348 "' must be bound before it is referenced"); 3349 } else { 3350 return emitError(loc, "successors can only be used at the top level"); 3351 } 3352 3353 return create<SuccessorVariable>(successor); 3354 } 3355 return emitError(loc, "expected variable to refer to an argument, region, " 3356 "result, or successor"); 3357 } 3358 3359 FailureOr<FormatElement *> 3360 OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, 3361 Context ctx) { 3362 switch (kind) { 3363 case FormatToken::kw_prop_dict: 3364 return parsePropDictDirective(loc, ctx); 3365 case FormatToken::kw_attr_dict: 3366 return parseAttrDictDirective(loc, ctx, 3367 /*withKeyword=*/false); 3368 case FormatToken::kw_attr_dict_w_keyword: 3369 return parseAttrDictDirective(loc, ctx, 3370 /*withKeyword=*/true); 3371 case FormatToken::kw_functional_type: 3372 return parseFunctionalTypeDirective(loc, ctx); 3373 case FormatToken::kw_operands: 3374 return 
parseOperandsDirective(loc, ctx); 3375 case FormatToken::kw_regions: 3376 return parseRegionsDirective(loc, ctx); 3377 case FormatToken::kw_results: 3378 return parseResultsDirective(loc, ctx); 3379 case FormatToken::kw_successors: 3380 return parseSuccessorsDirective(loc, ctx); 3381 case FormatToken::kw_type: 3382 return parseTypeDirective(loc, ctx); 3383 case FormatToken::kw_oilist: 3384 return parseOIListDirective(loc, ctx); 3385 3386 default: 3387 return emitError(loc, "unsupported directive kind"); 3388 } 3389 } 3390 3391 FailureOr<FormatElement *> 3392 OpFormatParser::parseAttrDictDirective(SMLoc loc, Context context, 3393 bool withKeyword) { 3394 if (context == TypeDirectiveContext) 3395 return emitError(loc, "'attr-dict' directive can only be used as a " 3396 "top-level directive"); 3397 3398 if (context == RefDirectiveContext) { 3399 if (!hasAttrDict) 3400 return emitError(loc, "'ref' of 'attr-dict' is not bound by a prior " 3401 "'attr-dict' directive"); 3402 3403 // Otherwise, this is a top-level context. 3404 } else { 3405 if (hasAttrDict) 3406 return emitError(loc, "'attr-dict' directive has already been seen"); 3407 hasAttrDict = true; 3408 } 3409 3410 return create<AttrDictDirective>(withKeyword); 3411 } 3412 3413 FailureOr<FormatElement *> 3414 OpFormatParser::parsePropDictDirective(SMLoc loc, Context context) { 3415 if (context == TypeDirectiveContext) 3416 return emitError(loc, "'prop-dict' directive can only be used as a " 3417 "top-level directive"); 3418 3419 if (context == RefDirectiveContext) 3420 llvm::report_fatal_error("'ref' of 'prop-dict' unsupported"); 3421 // Otherwise, this is a top-level context. 
3422 3423 if (hasPropDict) 3424 return emitError(loc, "'prop-dict' directive has already been seen"); 3425 hasPropDict = true; 3426 3427 return create<PropDictDirective>(); 3428 } 3429 3430 LogicalResult OpFormatParser::verifyCustomDirectiveArguments( 3431 SMLoc loc, ArrayRef<FormatElement *> arguments) { 3432 for (FormatElement *argument : arguments) { 3433 if (!isa<AttrDictDirective, PropDictDirective, AttributeVariable, 3434 OperandVariable, PropertyVariable, RefDirective, RegionVariable, 3435 SuccessorVariable, StringElement, TypeDirective>(argument)) { 3436 // TODO: FormatElement should have location info attached. 3437 return emitError(loc, "only variables and types may be used as " 3438 "parameters to a custom directive"); 3439 } 3440 if (auto *type = dyn_cast<TypeDirective>(argument)) { 3441 if (!isa<OperandVariable, ResultVariable>(type->getArg())) { 3442 return emitError(loc, "type directives within a custom directive may " 3443 "only refer to variables"); 3444 } 3445 } 3446 } 3447 return success(); 3448 } 3449 3450 FailureOr<FormatElement *> 3451 OpFormatParser::parseFunctionalTypeDirective(SMLoc loc, Context context) { 3452 if (context != TopLevelContext) 3453 return emitError( 3454 loc, "'functional-type' is only valid as a top-level directive"); 3455 3456 // Parse the main operand. 
3457 FailureOr<FormatElement *> inputs, results; 3458 if (failed(parseToken(FormatToken::l_paren, 3459 "expected '(' before argument list")) || 3460 failed(inputs = parseTypeDirectiveOperand(loc)) || 3461 failed(parseToken(FormatToken::comma, 3462 "expected ',' after inputs argument")) || 3463 failed(results = parseTypeDirectiveOperand(loc)) || 3464 failed( 3465 parseToken(FormatToken::r_paren, "expected ')' after argument list"))) 3466 return failure(); 3467 return create<FunctionalTypeDirective>(*inputs, *results); 3468 } 3469 3470 FailureOr<FormatElement *> 3471 OpFormatParser::parseOperandsDirective(SMLoc loc, Context context) { 3472 if (context == RefDirectiveContext) { 3473 if (!fmt.allOperands) 3474 return emitError(loc, "'ref' of 'operands' is not bound by a prior " 3475 "'operands' directive"); 3476 3477 } else if (context == TopLevelContext || context == CustomDirectiveContext) { 3478 if (fmt.allOperands || !seenOperands.empty()) 3479 return emitError(loc, "'operands' directive creates overlap in format"); 3480 fmt.allOperands = true; 3481 } 3482 return create<OperandsDirective>(); 3483 } 3484 3485 FailureOr<FormatElement *> 3486 OpFormatParser::parseRegionsDirective(SMLoc loc, Context context) { 3487 if (context == TypeDirectiveContext) 3488 return emitError(loc, "'regions' is only valid as a top-level directive"); 3489 if (context == RefDirectiveContext) { 3490 if (!hasAllRegions) 3491 return emitError(loc, "'ref' of 'regions' is not bound by a prior " 3492 "'regions' directive"); 3493 3494 // Otherwise, this is a TopLevel directive. 
3495 } else { 3496 if (hasAllRegions || !seenRegions.empty()) 3497 return emitError(loc, "'regions' directive creates overlap in format"); 3498 hasAllRegions = true; 3499 } 3500 return create<RegionsDirective>(); 3501 } 3502 3503 FailureOr<FormatElement *> 3504 OpFormatParser::parseResultsDirective(SMLoc loc, Context context) { 3505 if (context != TypeDirectiveContext) 3506 return emitError(loc, "'results' directive can can only be used as a child " 3507 "to a 'type' directive"); 3508 return create<ResultsDirective>(); 3509 } 3510 3511 FailureOr<FormatElement *> 3512 OpFormatParser::parseSuccessorsDirective(SMLoc loc, Context context) { 3513 if (context == TypeDirectiveContext) 3514 return emitError(loc, 3515 "'successors' is only valid as a top-level directive"); 3516 if (context == RefDirectiveContext) { 3517 if (!hasAllSuccessors) 3518 return emitError(loc, "'ref' of 'successors' is not bound by a prior " 3519 "'successors' directive"); 3520 3521 // Otherwise, this is a TopLevel directive. 
3522 } else { 3523 if (hasAllSuccessors || !seenSuccessors.empty()) 3524 return emitError(loc, "'successors' directive creates overlap in format"); 3525 hasAllSuccessors = true; 3526 } 3527 return create<SuccessorsDirective>(); 3528 } 3529 3530 FailureOr<FormatElement *> 3531 OpFormatParser::parseOIListDirective(SMLoc loc, Context context) { 3532 if (failed(parseToken(FormatToken::l_paren, 3533 "expected '(' before oilist argument list"))) 3534 return failure(); 3535 std::vector<FormatElement *> literalElements; 3536 std::vector<std::vector<FormatElement *>> parsingElements; 3537 do { 3538 FailureOr<FormatElement *> lelement = parseLiteral(context); 3539 if (failed(lelement)) 3540 return failure(); 3541 literalElements.push_back(*lelement); 3542 parsingElements.emplace_back(); 3543 std::vector<FormatElement *> &currParsingElements = parsingElements.back(); 3544 while (peekToken().getKind() != FormatToken::pipe && 3545 peekToken().getKind() != FormatToken::r_paren) { 3546 FailureOr<FormatElement *> pelement = parseElement(context); 3547 if (failed(pelement) || 3548 failed(verifyOIListParsingElement(*pelement, loc))) 3549 return failure(); 3550 currParsingElements.push_back(*pelement); 3551 } 3552 if (peekToken().getKind() == FormatToken::pipe) { 3553 consumeToken(); 3554 continue; 3555 } 3556 if (peekToken().getKind() == FormatToken::r_paren) { 3557 consumeToken(); 3558 break; 3559 } 3560 } while (true); 3561 3562 return create<OIListElement>(std::move(literalElements), 3563 std::move(parsingElements)); 3564 } 3565 3566 LogicalResult OpFormatParser::verifyOIListParsingElement(FormatElement *element, 3567 SMLoc loc) { 3568 SmallVector<VariableElement *> vars; 3569 collect(element, vars); 3570 for (VariableElement *elem : vars) { 3571 LogicalResult res = 3572 TypeSwitch<FormatElement *, LogicalResult>(elem) 3573 // Only optional attributes can be within an oilist parsing group. 
3574 .Case([&](AttributeVariable *attrEle) { 3575 if (!attrEle->getVar()->attr.isOptional() && 3576 !attrEle->getVar()->attr.hasDefaultValue()) 3577 return emitError(loc, "only optional attributes can be used in " 3578 "an oilist parsing group"); 3579 return success(); 3580 }) 3581 // Only optional properties can be within an oilist parsing group. 3582 .Case([&](PropertyVariable *propEle) { 3583 if (!propEle->getVar()->prop.hasDefaultValue()) 3584 return emitError( 3585 loc, 3586 "only default-valued or optional properties can be used in " 3587 "an olist parsing group"); 3588 return success(); 3589 }) 3590 // Only optional-like(i.e. variadic) operands can be within an 3591 // oilist parsing group. 3592 .Case([&](OperandVariable *ele) { 3593 if (!ele->getVar()->isVariableLength()) 3594 return emitError(loc, "only variable length operands can be " 3595 "used within an oilist parsing group"); 3596 return success(); 3597 }) 3598 // Only optional-like(i.e. variadic) results can be within an oilist 3599 // parsing group. 
3600 .Case([&](ResultVariable *ele) { 3601 if (!ele->getVar()->isVariableLength()) 3602 return emitError(loc, "only variable length results can be " 3603 "used within an oilist parsing group"); 3604 return success(); 3605 }) 3606 .Case([&](RegionVariable *) { return success(); }) 3607 .Default([&](FormatElement *) { 3608 return emitError(loc, 3609 "only literals, types, and variables can be " 3610 "used within an oilist group"); 3611 }); 3612 if (failed(res)) 3613 return failure(); 3614 } 3615 return success(); 3616 } 3617 3618 FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc, 3619 Context context) { 3620 if (context == TypeDirectiveContext) 3621 return emitError(loc, "'type' cannot be used as a child of another `type`"); 3622 3623 bool isRefChild = context == RefDirectiveContext; 3624 FailureOr<FormatElement *> operand; 3625 if (failed(parseToken(FormatToken::l_paren, 3626 "expected '(' before argument list")) || 3627 failed(operand = parseTypeDirectiveOperand(loc, isRefChild)) || 3628 failed( 3629 parseToken(FormatToken::r_paren, "expected ')' after argument list"))) 3630 return failure(); 3631 3632 return create<TypeDirective>(*operand); 3633 } 3634 3635 LogicalResult OpFormatParser::markQualified(SMLoc loc, FormatElement *element) { 3636 return TypeSwitch<FormatElement *, LogicalResult>(element) 3637 .Case<AttributeVariable, TypeDirective>([](auto *element) { 3638 element->setShouldBeQualified(); 3639 return success(); 3640 }) 3641 .Default([&](auto *element) { 3642 return this->emitError( 3643 loc, 3644 "'qualified' directive expects an attribute or a `type` directive"); 3645 }); 3646 } 3647 3648 FailureOr<FormatElement *> 3649 OpFormatParser::parseTypeDirectiveOperand(SMLoc loc, bool isRefChild) { 3650 FailureOr<FormatElement *> result = parseElement(TypeDirectiveContext); 3651 if (failed(result)) 3652 return failure(); 3653 3654 FormatElement *element = *result; 3655 if (isa<LiteralElement>(element)) 3656 return emitError( 3657 loc, 
"'type' directive operand expects variable or directive operand"); 3658 3659 if (auto *var = dyn_cast<OperandVariable>(element)) { 3660 unsigned opIdx = var->getVar() - op.operand_begin(); 3661 if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.test(opIdx))) 3662 return emitError(loc, "'type' of '" + var->getVar()->name + 3663 "' is already bound"); 3664 if (isRefChild && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx))) 3665 return emitError(loc, "'ref' of 'type($" + var->getVar()->name + 3666 ")' is not bound by a prior 'type' directive"); 3667 seenOperandTypes.set(opIdx); 3668 } else if (auto *var = dyn_cast<ResultVariable>(element)) { 3669 unsigned resIdx = var->getVar() - op.result_begin(); 3670 if (!isRefChild && (fmt.allResultTypes || seenResultTypes.test(resIdx))) 3671 return emitError(loc, "'type' of '" + var->getVar()->name + 3672 "' is already bound"); 3673 if (isRefChild && !(fmt.allResultTypes || seenResultTypes.test(resIdx))) 3674 return emitError(loc, "'ref' of 'type($" + var->getVar()->name + 3675 ")' is not bound by a prior 'type' directive"); 3676 seenResultTypes.set(resIdx); 3677 } else if (isa<OperandsDirective>(&*element)) { 3678 if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.any())) 3679 return emitError(loc, "'operands' 'type' is already bound"); 3680 if (isRefChild && !fmt.allOperandTypes) 3681 return emitError(loc, "'ref' of 'type(operands)' is not bound by a prior " 3682 "'type' directive"); 3683 fmt.allOperandTypes = true; 3684 } else if (isa<ResultsDirective>(&*element)) { 3685 if (!isRefChild && (fmt.allResultTypes || seenResultTypes.any())) 3686 return emitError(loc, "'results' 'type' is already bound"); 3687 if (isRefChild && !fmt.allResultTypes) 3688 return emitError(loc, "'ref' of 'type(results)' is not bound by a prior " 3689 "'type' directive"); 3690 fmt.allResultTypes = true; 3691 } else { 3692 return emitError(loc, "invalid argument to 'type' directive"); 3693 } 3694 return element; 3695 } 3696 3697 
LogicalResult OpFormatParser::verifyOptionalGroupElements( 3698 SMLoc loc, ArrayRef<FormatElement *> elements, FormatElement *anchor) { 3699 for (FormatElement *element : elements) { 3700 if (failed(verifyOptionalGroupElement(loc, element, element == anchor))) 3701 return failure(); 3702 } 3703 return success(); 3704 } 3705 3706 LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc, 3707 FormatElement *element, 3708 bool isAnchor) { 3709 return TypeSwitch<FormatElement *, LogicalResult>(element) 3710 // All attributes can be within the optional group, but only optional 3711 // attributes can be the anchor. 3712 .Case([&](AttributeVariable *attrEle) { 3713 Attribute attr = attrEle->getVar()->attr; 3714 if (isAnchor && !(attr.isOptional() || attr.hasDefaultValue())) 3715 return emitError(loc, "only optional or default-valued attributes " 3716 "can be used to anchor an optional group"); 3717 return success(); 3718 }) 3719 // All properties can be within the optional group, but only optional 3720 // properties can be the anchor. 3721 .Case([&](PropertyVariable *propEle) { 3722 Property prop = propEle->getVar()->prop; 3723 if (isAnchor && !(prop.hasDefaultValue() && prop.hasOptionalParser())) 3724 return emitError(loc, "only properties with default values " 3725 "that can be optionally parsed " 3726 "can be used to anchor an optional group"); 3727 return success(); 3728 }) 3729 // Only optional-like(i.e. variadic) operands can be within an optional 3730 // group. 3731 .Case([&](OperandVariable *ele) { 3732 if (!ele->getVar()->isVariableLength()) 3733 return emitError(loc, "only variable length operands can be used " 3734 "within an optional group"); 3735 return success(); 3736 }) 3737 // Only optional-like(i.e. variadic) results can be within an optional 3738 // group. 
3739 .Case([&](ResultVariable *ele) { 3740 if (!ele->getVar()->isVariableLength()) 3741 return emitError(loc, "only variable length results can be used " 3742 "within an optional group"); 3743 return success(); 3744 }) 3745 .Case([&](RegionVariable *) { 3746 // TODO: When ODS has proper support for marking "optional" regions, add 3747 // a check here. 3748 return success(); 3749 }) 3750 .Case([&](TypeDirective *ele) { 3751 return verifyOptionalGroupElement(loc, ele->getArg(), 3752 /*isAnchor=*/false); 3753 }) 3754 .Case([&](FunctionalTypeDirective *ele) { 3755 if (failed(verifyOptionalGroupElement(loc, ele->getInputs(), 3756 /*isAnchor=*/false))) 3757 return failure(); 3758 return verifyOptionalGroupElement(loc, ele->getResults(), 3759 /*isAnchor=*/false); 3760 }) 3761 .Case([&](CustomDirective *ele) { 3762 if (!isAnchor) 3763 return success(); 3764 // Verify each child as being valid in an optional group. They are all 3765 // potential anchors if the custom directive was marked as one. 3766 for (FormatElement *child : ele->getArguments()) { 3767 if (isa<RefDirective>(child)) 3768 continue; 3769 if (failed(verifyOptionalGroupElement(loc, child, /*isAnchor=*/true))) 3770 return failure(); 3771 } 3772 return success(); 3773 }) 3774 // Literals, whitespace, and custom directives may be used, but they can't 3775 // anchor the group. 
3776 .Case<LiteralElement, WhitespaceElement, OptionalElement>( 3777 [&](FormatElement *) { 3778 if (isAnchor) 3779 return emitError(loc, "only variables and types can be used " 3780 "to anchor an optional group"); 3781 return success(); 3782 }) 3783 .Default([&](FormatElement *) { 3784 return emitError(loc, "only literals, types, and variables can be " 3785 "used within an optional group"); 3786 }); 3787 } 3788 3789 //===----------------------------------------------------------------------===// 3790 // Interface 3791 //===----------------------------------------------------------------------===// 3792 3793 void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass, 3794 bool hasProperties) { 3795 // TODO: Operator doesn't expose all necessary functionality via 3796 // the const interface. 3797 Operator &op = const_cast<Operator &>(constOp); 3798 if (!op.hasAssemblyFormat()) 3799 return; 3800 3801 // Parse the format description. 3802 llvm::SourceMgr mgr; 3803 mgr.AddNewSourceBuffer( 3804 llvm::MemoryBuffer::getMemBuffer(op.getAssemblyFormat()), SMLoc()); 3805 OperationFormat format(op, hasProperties); 3806 OpFormatParser parser(mgr, format, op); 3807 FailureOr<std::vector<FormatElement *>> elements = parser.parse(); 3808 if (failed(elements)) { 3809 // Exit the process if format errors are treated as fatal. 3810 if (formatErrorIsFatal) { 3811 // Invoke the interrupt handlers to run the file cleanup handlers. 3812 llvm::sys::RunInterruptHandlers(); 3813 std::exit(1); 3814 } 3815 return; 3816 } 3817 format.elements = std::move(*elements); 3818 3819 // Generate the printer and parser based on the parsed format. 3820 format.genParser(op, opClass); 3821 format.genPrinter(op, opClass); 3822 } 3823