xref: /llvm-project/mlir/tools/mlir-tblgen/OpFormatGen.cpp (revision 378e1793379c9c63a4265ecf55c47308410ed25d)
1 //===- OpFormatGen.cpp - MLIR operation asm format generator --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "OpFormatGen.h"
10 #include "FormatGen.h"
11 #include "OpClass.h"
12 #include "mlir/Support/LLVM.h"
13 #include "mlir/TableGen/Class.h"
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "mlir/TableGen/Trait.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallBitVector.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Signals.h"
24 #include "llvm/Support/SourceMgr.h"
25 #include "llvm/TableGen/Record.h"
26 
27 #define DEBUG_TYPE "mlir-tblgen-opformatgen"
28 
29 using namespace mlir;
30 using namespace mlir::tblgen;
31 using llvm::formatv;
32 using llvm::Record;
33 using llvm::StringMap;
34 
35 //===----------------------------------------------------------------------===//
36 // VariableElement
37 
38 namespace {
39 /// This class represents an instance of an op variable element. A variable
40 /// refers to something registered on the operation itself, e.g. an operand,
41 /// result, attribute, region, or successor.
42 template <typename VarT, VariableElement::Kind VariableKind>
43 class OpVariableElement : public VariableElementBase<VariableKind> {
44 public:
45   using Base = OpVariableElement<VarT, VariableKind>;
46 
47   /// Create an op variable element with the variable value.
48   OpVariableElement(const VarT *var) : var(var) {}
49 
50   /// Get the variable.
51   const VarT *getVar() const { return var; }
52 
53 protected:
54   /// The op variable, e.g. a type or attribute constraint.
55   const VarT *var;
56 };
57 
58 /// This class represents a variable that refers to an attribute argument.
59 struct AttributeVariable
60     : public OpVariableElement<NamedAttribute, VariableElement::Attribute> {
61   using Base::Base;
62 
63   /// Return the constant builder call for the type of this attribute, or
64   /// std::nullopt if it doesn't have one.
65   std::optional<StringRef> getTypeBuilder() const {
66     std::optional<Type> attrType = var->attr.getValueType();
67     return attrType ? attrType->getBuilderCall() : std::nullopt;
68   }
69 
70   /// Indicate if this attribute is printed "qualified" (that is it is
71   /// prefixed with the `#dialect.mnemonic`).
72   bool shouldBeQualified() { return shouldBeQualifiedFlag; }
73   void setShouldBeQualified(bool qualified = true) {
74     shouldBeQualifiedFlag = qualified;
75   }
76 
77 private:
78   bool shouldBeQualifiedFlag = false;
79 };
80 
81 /// This class represents a variable that refers to an operand argument.
82 using OperandVariable =
83     OpVariableElement<NamedTypeConstraint, VariableElement::Operand>;
84 
85 /// This class represents a variable that refers to a result.
86 using ResultVariable =
87     OpVariableElement<NamedTypeConstraint, VariableElement::Result>;
88 
89 /// This class represents a variable that refers to a region.
90 using RegionVariable = OpVariableElement<NamedRegion, VariableElement::Region>;
91 
92 /// This class represents a variable that refers to a successor.
93 using SuccessorVariable =
94     OpVariableElement<NamedSuccessor, VariableElement::Successor>;
95 
96 /// This class represents a variable that refers to a property argument.
97 using PropertyVariable =
98     OpVariableElement<NamedProperty, VariableElement::Property>;
99 
100 /// LLVM RTTI helper for attribute-like variables, that is, attributes or
101 /// properties. This allows for common handling of attributes and properties in
102 /// parts of the code that are oblivious to whether something is stored as an
103 /// attribute or a property.
104 struct AttributeLikeVariable : public VariableElement {
105   enum { AttributeLike = 1 << 0 };
106 
107   static bool classof(const VariableElement *ve) {
108     return ve->getKind() == VariableElement::Attribute ||
109            ve->getKind() == VariableElement::Property;
110   }
111 
112   static bool classof(const FormatElement *fe) {
113     return isa<VariableElement>(fe) && classof(cast<VariableElement>(fe));
114   }
115 
116   /// Returns true if the variable is a UnitAttr or a UnitProp.
117   bool isUnit() const {
118     if (const auto *attr = dyn_cast<AttributeVariable>(this))
119       return attr->getVar()->attr.getBaseAttr().getAttrDefName() == "UnitAttr";
120     if (const auto *prop = dyn_cast<PropertyVariable>(this)) {
121       StringRef baseDefName =
122           prop->getVar()->prop.getBaseProperty().getPropertyDefName();
123       // Note: remove the `UnitProperty` case once the deprecation period is
124       // over.
125       return baseDefName == "UnitProp" || baseDefName == "UnitProperty";
126     }
127     llvm_unreachable("Type that wasn't listed in classof()");
128   }
129 
130   StringRef getName() const {
131     if (const auto *attr = dyn_cast<AttributeVariable>(this))
132       return attr->getVar()->name;
133     if (const auto *prop = dyn_cast<PropertyVariable>(this))
134       return prop->getVar()->name;
135     llvm_unreachable("Type that wasn't listed in classof()");
136   }
137 };
138 } // namespace
139 
140 //===----------------------------------------------------------------------===//
141 // DirectiveElement
142 
143 namespace {
144 /// This class represents the `operands` directive. This directive represents
145 /// all of the operands of an operation.
146 using OperandsDirective = DirectiveElementBase<DirectiveElement::Operands>;
147 
148 /// This class represents the `results` directive. This directive represents
149 /// all of the results of an operation.
150 using ResultsDirective = DirectiveElementBase<DirectiveElement::Results>;
151 
152 /// This class represents the `regions` directive. This directive represents
153 /// all of the regions of an operation.
154 using RegionsDirective = DirectiveElementBase<DirectiveElement::Regions>;
155 
156 /// This class represents the `successors` directive. This directive represents
157 /// all of the successors of an operation.
158 using SuccessorsDirective = DirectiveElementBase<DirectiveElement::Successors>;
159 
160 /// This class represents the `attr-dict` directive. This directive represents
161 /// the attribute dictionary of the operation.
162 class AttrDictDirective
163     : public DirectiveElementBase<DirectiveElement::AttrDict> {
164 public:
165   explicit AttrDictDirective(bool withKeyword) : withKeyword(withKeyword) {}
166 
167   /// Return whether the dictionary should be printed with the 'attributes'
168   /// keyword.
169   bool isWithKeyword() const { return withKeyword; }
170 
171 private:
172   /// If the dictionary should be printed with the 'attributes' keyword.
173   bool withKeyword;
174 };
175 
176 /// This class represents the `prop-dict` directive. This directive represents
177 /// the properties of the operation, expressed as a directionary.
178 class PropDictDirective
179     : public DirectiveElementBase<DirectiveElement::PropDict> {
180 public:
181   explicit PropDictDirective() = default;
182 };
183 
184 /// This class represents the `functional-type` directive. This directive takes
185 /// two arguments and formats them, respectively, as the inputs and results of a
186 /// FunctionType.
187 class FunctionalTypeDirective
188     : public DirectiveElementBase<DirectiveElement::FunctionalType> {
189 public:
190   FunctionalTypeDirective(FormatElement *inputs, FormatElement *results)
191       : inputs(inputs), results(results) {}
192 
193   FormatElement *getInputs() const { return inputs; }
194   FormatElement *getResults() const { return results; }
195 
196 private:
197   /// The input and result arguments.
198   FormatElement *inputs, *results;
199 };
200 
201 /// This class represents the `type` directive.
202 class TypeDirective : public DirectiveElementBase<DirectiveElement::Type> {
203 public:
204   TypeDirective(FormatElement *arg) : arg(arg) {}
205 
206   FormatElement *getArg() const { return arg; }
207 
208   /// Indicate if this type is printed "qualified" (that is it is
209   /// prefixed with the `!dialect.mnemonic`).
210   bool shouldBeQualified() { return shouldBeQualifiedFlag; }
211   void setShouldBeQualified(bool qualified = true) {
212     shouldBeQualifiedFlag = qualified;
213   }
214 
215 private:
216   /// The argument that is used to format the directive.
217   FormatElement *arg;
218 
219   bool shouldBeQualifiedFlag = false;
220 };
221 
222 /// This class represents a group of order-independent optional clauses. Each
223 /// clause starts with a literal element and has a coressponding parsing
224 /// element. A parsing element is a continous sequence of format elements.
225 /// Each clause can appear 0 or 1 time.
226 class OIListElement : public DirectiveElementBase<DirectiveElement::OIList> {
227 public:
228   OIListElement(std::vector<FormatElement *> &&literalElements,
229                 std::vector<std::vector<FormatElement *>> &&parsingElements)
230       : literalElements(std::move(literalElements)),
231         parsingElements(std::move(parsingElements)) {}
232 
233   /// Returns a range to iterate over the LiteralElements.
234   auto getLiteralElements() const {
235     return llvm::map_range(literalElements, [](FormatElement *el) {
236       return cast<LiteralElement>(el);
237     });
238   }
239 
240   /// Returns a range to iterate over the parsing elements corresponding to the
241   /// clauses.
242   ArrayRef<std::vector<FormatElement *>> getParsingElements() const {
243     return parsingElements;
244   }
245 
246   /// Returns a range to iterate over tuples of parsing and literal elements.
247   auto getClauses() const {
248     return llvm::zip(getLiteralElements(), getParsingElements());
249   }
250 
251   /// If the parsing element is a single UnitAttr element, then it returns the
252   /// attribute variable. Otherwise, returns nullptr.
253   AttributeLikeVariable *
254   getUnitVariableParsingElement(ArrayRef<FormatElement *> pelement) {
255     if (pelement.size() == 1) {
256       auto *attrElem = dyn_cast<AttributeLikeVariable>(pelement[0]);
257       if (attrElem && attrElem->isUnit())
258         return attrElem;
259     }
260     return nullptr;
261   }
262 
263 private:
264   /// A vector of `LiteralElement` objects. Each element stores the keyword
265   /// for one case of oilist element. For example, an oilist element along with
266   /// the `literalElements` vector:
267   /// ```
268   ///  oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
269   ///  literalElements = { `keyword`, `otherKeyword` }
270   /// ```
271   std::vector<FormatElement *> literalElements;
272 
273   /// A vector of valid declarative assembly format vectors. Each object in
274   /// parsing elements is a vector of elements in assembly format syntax.
275   /// For example, an oilist element along with the parsingElements vector:
276   /// ```
277   ///  oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
278   ///  parsingElements = {
279   ///    { `=`, `(`, $arg0, `)` },
280   ///    { `<`, $arg1, `>` }
281   ///  }
282   /// ```
283   std::vector<std::vector<FormatElement *>> parsingElements;
284 };
285 } // namespace
286 
287 //===----------------------------------------------------------------------===//
288 // OperationFormat
289 //===----------------------------------------------------------------------===//
290 
291 namespace {
292 
293 using ConstArgument =
294     llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>;
295 
296 struct OperationFormat {
297   /// This class represents a specific resolver for an operand or result type.
298   class TypeResolution {
299   public:
300     TypeResolution() = default;
301 
302     /// Get the index into the buildable types for this type, or std::nullopt.
303     std::optional<int> getBuilderIdx() const { return builderIdx; }
304     void setBuilderIdx(int idx) { builderIdx = idx; }
305 
306     /// Get the variable this type is resolved to, or nullptr.
307     const NamedTypeConstraint *getVariable() const {
308       return llvm::dyn_cast_if_present<const NamedTypeConstraint *>(resolver);
309     }
310     /// Get the attribute this type is resolved to, or nullptr.
311     const NamedAttribute *getAttribute() const {
312       return llvm::dyn_cast_if_present<const NamedAttribute *>(resolver);
313     }
314     /// Get the transformer for the type of the variable, or std::nullopt.
315     std::optional<StringRef> getVarTransformer() const {
316       return variableTransformer;
317     }
318     void setResolver(ConstArgument arg, std::optional<StringRef> transformer) {
319       resolver = arg;
320       variableTransformer = transformer;
321       assert(getVariable() || getAttribute());
322     }
323 
324   private:
325     /// If the type is resolved with a buildable type, this is the index into
326     /// 'buildableTypes' in the parent format.
327     std::optional<int> builderIdx;
328     /// If the type is resolved based upon another operand or result, this is
329     /// the variable or the attribute that this type is resolved to.
330     ConstArgument resolver;
331     /// If the type is resolved based upon another operand or result, this is
332     /// a transformer to apply to the variable when resolving.
333     std::optional<StringRef> variableTransformer;
334   };
335 
336   /// The context in which an element is generated.
337   enum class GenContext {
338     /// The element is generated at the top-level or with the same behaviour.
339     Normal,
340     /// The element is generated inside an optional group.
341     Optional
342   };
343 
344   OperationFormat(const Operator &op, bool hasProperties)
345       : useProperties(hasProperties), opCppClassName(op.getCppClassName()) {
346     operandTypes.resize(op.getNumOperands(), TypeResolution());
347     resultTypes.resize(op.getNumResults(), TypeResolution());
348 
349     hasImplicitTermTrait = llvm::any_of(op.getTraits(), [](const Trait &trait) {
350       return trait.getDef().isSubClassOf("SingleBlockImplicitTerminatorImpl");
351     });
352 
353     hasSingleBlockTrait = op.getTrait("::mlir::OpTrait::SingleBlock");
354   }
355 
356   /// Generate the operation parser from this format.
357   void genParser(Operator &op, OpClass &opClass);
358   /// Generate the parser code for a specific format element.
359   void genElementParser(FormatElement *element, MethodBody &body,
360                         FmtContext &attrTypeCtx,
361                         GenContext genCtx = GenContext::Normal);
362   /// Generate the C++ to resolve the types of operands and results during
363   /// parsing.
364   void genParserTypeResolution(Operator &op, MethodBody &body);
365   /// Generate the C++ to resolve the types of the operands during parsing.
366   void genParserOperandTypeResolution(
367       Operator &op, MethodBody &body,
368       function_ref<void(TypeResolution &, StringRef)> emitTypeResolver);
369   /// Generate the C++ to resolve regions during parsing.
370   void genParserRegionResolution(Operator &op, MethodBody &body);
371   /// Generate the C++ to resolve successors during parsing.
372   void genParserSuccessorResolution(Operator &op, MethodBody &body);
373   /// Generate the C++ to handling variadic segment size traits.
374   void genParserVariadicSegmentResolution(Operator &op, MethodBody &body);
375 
376   /// Generate the operation printer from this format.
377   void genPrinter(Operator &op, OpClass &opClass);
378 
379   /// Generate the printer code for a specific format element.
380   void genElementPrinter(FormatElement *element, MethodBody &body, Operator &op,
381                          bool &shouldEmitSpace, bool &lastWasPunctuation);
382 
383   /// The various elements in this format.
384   std::vector<FormatElement *> elements;
385 
386   /// A flag indicating if all operand/result types were seen. If the format
387   /// contains these, it can not contain individual type resolvers.
388   bool allOperands = false, allOperandTypes = false, allResultTypes = false;
389 
390   /// A flag indicating if this operation infers its result types
391   bool infersResultTypes = false;
392 
393   /// A flag indicating if this operation has the SingleBlockImplicitTerminator
394   /// trait.
395   bool hasImplicitTermTrait;
396 
397   /// A flag indicating if this operation has the SingleBlock trait.
398   bool hasSingleBlockTrait;
399 
400   /// Indicate whether we need to use properties for the current operator.
401   bool useProperties;
402 
403   /// Indicate whether prop-dict is used in the format
404   bool hasPropDict;
405 
406   /// The Operation class name
407   StringRef opCppClassName;
408 
409   /// A map of buildable types to indices.
410   llvm::MapVector<StringRef, int, StringMap<int>> buildableTypes;
411 
412   /// The index of the buildable type, if valid, for every operand and result.
413   std::vector<TypeResolution> operandTypes, resultTypes;
414 
415   /// The set of attributes explicitly used within the format.
416   llvm::SmallSetVector<const NamedAttribute *, 8> usedAttributes;
417   llvm::StringSet<> inferredAttributes;
418 
419   /// The set of properties explicitly used within the format.
420   llvm::SmallSetVector<const NamedProperty *, 8> usedProperties;
421 };
422 } // namespace
423 
424 //===----------------------------------------------------------------------===//
425 // Parser Gen
426 
427 /// Returns true if we can format the given attribute as an EnumAttr in the
428 /// parser format.
429 static bool canFormatEnumAttr(const NamedAttribute *attr) {
430   Attribute baseAttr = attr->attr.getBaseAttr();
431   const EnumAttr *enumAttr = dyn_cast<EnumAttr>(&baseAttr);
432   if (!enumAttr)
433     return false;
434 
435   // The attribute must have a valid underlying type and a constant builder.
436   return !enumAttr->getUnderlyingType().empty() &&
437          !enumAttr->getConstBuilderTemplate().empty();
438 }
439 
440 /// Returns if we should format the given attribute as an SymbolNameAttr.
441 static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) {
442   return attr->attr.getBaseAttr().getAttrDefName() == "SymbolNameAttr";
443 }
444 
/// The code snippet used to generate a parser call for an attribute through
/// `parser.parseCustomAttributeWithFallback`.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const attrParserCode = R"(
  if (parser.parseCustomAttributeWithFallback({0}Attr, {1})) {{
    return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for an attribute through
/// the generic `parser.parseAttribute`.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const genericAttrParserCode = R"(
  if (parser.parseAttribute({0}Attr, {1}))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for an optional attribute.
/// A present-but-malformed attribute returns failure. The snippet ends with an
/// `if (...)` header that has no body; it is true only when the attribute was
/// present and parsed successfully. The guarded statement must be emitted
/// directly after this snippet (presumably by the caller).
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const optionalAttrParserCode = R"(
  ::mlir::OptionalParseResult parseResult{0}Attr =
    parser.parseOptionalAttribute({0}Attr, {1});
  if (parseResult{0}Attr.has_value() && failed(*parseResult{0}Attr))
    return ::mlir::failure();
  if (parseResult{0}Attr.has_value() && succeeded(*parseResult{0}Attr))
)";

/// The code snippet used to generate a parser call for a symbol name attribute.
///
/// {0}: The name of the attribute.
const char *const symbolNameAttrParserCode = R"(
  if (parser.parseSymbolName({0}Attr))
    return ::mlir::failure();
)";
/// The optional variant of `symbolNameAttrParserCode`. The parse result is
/// discarded.
///
/// {0}: The name of the attribute.
const char *const optionalSymbolNameAttrParserCode = R"(
  // Parsing an optional symbol name doesn't fail, so no need to check the
  // result.
  (void)parser.parseOptionalSymbolName({0}Attr);
)";
484 
/// The code snippet used to generate a parser call for an enum attribute.
/// The enum case is accepted either as a bare keyword from the allowed set
/// ({4}) or as a string attribute. If neither is present, the {5} code runs.
/// A non-empty string that does not symbolize to an enum case emits an
/// "invalid ... attribute specification" error.
///
/// {0}: The name of the attribute.
/// {1}: The c++ namespace for the enum symbolize functions.
/// {2}: The function to symbolize a string of the enum.
/// {3}: The constant builder call to create an attribute of the enum type.
/// {4}: The set of allowed enum keywords.
/// {5}: The error message on failure when the enum isn't present.
/// {6}: The attribute assignment expression
///
/// The error-return statement ends with a single `;`. A stray `;;` would
/// leave an empty statement in every generated parser.
const char *const enumAttrParserCode = R"(
  {
    ::llvm::StringRef attrStr;
    ::mlir::NamedAttrList attrStorage;
    auto loc = parser.getCurrentLocation();
    if (parser.parseOptionalKeyword(&attrStr, {4})) {
      ::mlir::StringAttr attrVal;
      ::mlir::OptionalParseResult parseResult =
        parser.parseOptionalAttribute(attrVal,
                                      parser.getBuilder().getNoneType(),
                                      "{0}", attrStorage);
      if (parseResult.has_value()) {{
        if (failed(*parseResult))
          return ::mlir::failure();
        attrStr = attrVal.getValue();
      } else {
        {5}
      }
    }
    if (!attrStr.empty()) {
      auto attrOptional = {1}::{2}(attrStr);
      if (!attrOptional)
        return parser.emitError(loc, "invalid ")
               << "{0} attribute specification: \"" << attrStr << '"';

      {0}Attr = {3};
      {6}
    }
  }
)";
524 
/// The code snippet used to generate a parser call for a property. The
/// property's parser code runs in an immediately-invoked lambda over the
/// property's storage in `result`. On failure, an error naming the property is
/// emitted at the location where parsing of the property began.
/// {0}: The name of the property
/// {1}: The C++ class name of the operation
/// {2}: The property's parser code with appropriate substitutions performed
/// {3}: The description of the expected property for the error message.
const char *const propertyParserCode = R"(
  auto {0}PropLoc = parser.getCurrentLocation();
  auto {0}PropParseResult = [&](auto& propStorage) -> ::mlir::ParseResult {{
    {2}
    return ::mlir::success();
  }(result.getOrAddProperties<{1}::Properties>().{0});
  if (failed({0}PropParseResult)) {{
    return parser.emitError({0}PropLoc, "invalid value for property {0}, expected {3}");
  }
)";

/// The code snippet used to generate a parser call for an optional property.
/// The lambda yields an `::mlir::OptionalParseResult`. Only a present-but-failed
/// parse returns failure, and unlike `propertyParserCode` no extra error is
/// emitted here.
/// {0}: The name of the property
/// {1}: The C++ class name of the operation
/// {2}: The property's parser code with appropriate substitutions performed
const char *const optionalPropertyParserCode = R"(
  auto {0}PropParseResult = [&](auto& propStorage) -> ::mlir::OptionalParseResult {{
    {2}
    return ::mlir::success();
  }(result.getOrAddProperties<{1}::Properties>().{0});
  if ({0}PropParseResult.has_value() && failed(*{0}PropParseResult)) {{
    return ::mlir::failure();
  }
)";
554 
/// The code snippet used to generate a parser call for a variadic operand.
/// It records the location where the operand list starts.
///
/// {0}: The name of the operand.
const char *const variadicOperandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperandList({0}Operands))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for an optional operand.
/// The operand is appended to `{0}Operands` only when one is present.
///
/// {0}: The name of the operand.
const char *const optionalOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    ::mlir::OpAsmParser::UnresolvedOperand operand;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalOperand(operand);
    if (parseResult.has_value()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Operands.push_back(operand);
    }
  }
)";
/// The code snippet used to generate a parser call for a single operand.
///
/// {0}: The name of the operand.
const char *const operandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperand({0}RawOperand))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for a VariadicOfVariadic
/// operand. Groups are parenthesized operand lists separated by commas. The
/// size of each group is appended to `{0}OperandGroupSizes`.
///
/// {0}: The name of the operand.
/// {1}: The name of segment size attribute.
/// NOTE(review): {1} is not referenced by this snippet's text.
const char *const variadicOfVariadicOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    int32_t curSize = 0;
    do {
      if (parser.parseOptionalLParen())
        break;
      if (parser.parseOperandList({0}Operands) || parser.parseRParen())
        return ::mlir::failure();
      {0}OperandGroupSizes.push_back({0}Operands.size() - curSize);
      curSize = {0}Operands.size();
    } while (succeeded(parser.parseOptionalComma()));
  }
)";
600 
/// The code snippet used to generate a parser call for a VariadicOfVariadic
/// type list. Groups are parenthesized and comma-separated, and a group may be
/// empty (`()` is accepted via the leading `parseOptionalRParen`).
///
/// {0}: The name for the type list.
const char *const variadicOfVariadicTypeParserCode = R"(
  do {
    if (parser.parseOptionalLParen())
      break;
    if (parser.parseOptionalRParen() &&
        (parser.parseTypeList({0}Types) || parser.parseRParen()))
      return ::mlir::failure();
  } while (succeeded(parser.parseOptionalComma()));
)";
/// The code snippet used to generate a parser call for a variadic type list.
///
/// {0}: The name for the type list.
const char *const variadicTypeParserCode = R"(
  if (parser.parseTypeList({0}Types))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for an optional type. The
/// type is appended to `{0}Types` only when one is present.
///
/// {0}: The name for the type list.
const char *const optionalTypeParserCode = R"(
  {
    ::mlir::Type optionalType;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalType(optionalType);
    if (parseResult.has_value()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Types.push_back(optionalType);
    }
  }
)";
/// The code snippet used to generate a parser call for a single type through
/// `parser.parseCustomTypeWithFallback`.
///
/// {0}: The C++ type that the parsed value is declared as.
/// {1}: The name of the type variable (stored into `{1}RawType`).
const char *const typeParserCode = R"(
  {
    {0} type;
    if (parser.parseCustomTypeWithFallback(type))
      return ::mlir::failure();
    {1}RawType = type;
  }
)";
/// The code snippet used to generate a parser call for a single type through
/// the generic `parser.parseType` (presumably for types that are written
/// qualified). Only {1} is referenced.
///
/// {1}: The name of the type variable (stored into `{1}RawType`).
const char *const qualifiedTypeParserCode = R"(
  if (parser.parseType({1}RawType))
    return ::mlir::failure();
)";
641 
/// The code snippet used to generate a parser call for a functional type. It
/// parses a single `::mlir::FunctionType` and assigns its inputs to
/// `{0}Types` and its results to `{1}Types`.
///
/// {0}: The name for the input type list.
/// {1}: The name for the result type list.
const char *const functionalTypeParserCode = R"(
  ::mlir::FunctionType {0}__{1}_functionType;
  if (parser.parseType({0}__{1}_functionType))
    return ::mlir::failure();
  {0}Types = {0}__{1}_functionType.getInputs();
  {1}Types = {0}__{1}_functionType.getResults();
)";
653 
/// The code snippet used to generate a parser call to infer return types. It
/// calls the op's static `inferReturnTypes` with the already-parsed operands,
/// attribute dictionary, raw properties, and regions of `result`. It then
/// adds the inferred types to `result`, returning failure if inference fails.
///
/// {0}: The operation class name
const char *const inferReturnTypesParserCode = R"(
  ::llvm::SmallVector<::mlir::Type> inferredReturnTypes;
  if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
      result.location, result.operands,
      result.attributes.getDictionary(parser.getContext()),
      result.getRawProperties(),
      result.regions, inferredReturnTypes)))
    return ::mlir::failure();
  result.addTypes(inferredReturnTypes);
)";
667 
// Like every other snippet in this file, these are `const char *const`. The
// pointers cannot be reassigned, and (being const at namespace scope) they
// have internal linkage instead of being mutable external globals.

/// The code snippet used to generate a parser call for a region list. The
/// list may be empty; otherwise it is a comma-separated sequence of regions.
///
/// {0}: The name for the region list.
const char *const regionListParserCode = R"(
  {
    std::unique_ptr<::mlir::Region> region;
    auto firstRegionResult = parser.parseOptionalRegion(region);
    if (firstRegionResult.has_value()) {
      if (failed(*firstRegionResult))
        return ::mlir::failure();
      {0}Regions.emplace_back(std::move(region));

      // Parse any trailing regions.
      while (succeeded(parser.parseOptionalComma())) {
        region = std::make_unique<::mlir::Region>();
        if (parser.parseRegion(*region))
          return ::mlir::failure();
        {0}Regions.emplace_back(std::move(region));
      }
    }
  }
)";

/// The code snippet used to ensure a list of regions have terminators.
///
/// {0}: The name of the region list.
const char *const regionListEnsureTerminatorParserCode = R"(
  for (auto &region : {0}Regions)
    ensureTerminator(*region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure every region in a list has a block, by
/// emplacing one into any empty region.
///
/// {0}: The name of the region list.
const char *const regionListEnsureSingleBlockParserCode = R"(
  for (auto &region : {0}Regions)
    if (region->empty()) region->emplaceBlock();
)";
706 
// Declared `const char *const` for consistency with the other snippets: the
// pointers are immutable and have internal linkage.

/// The code snippet used to generate a parser call for an optional region.
/// Only a present-but-malformed region returns failure.
///
/// {0}: The name of the region.
const char *const optionalRegionParserCode = R"(
  {
     auto parseResult = parser.parseOptionalRegion(*{0}Region);
     if (parseResult.has_value() && failed(*parseResult))
       return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for a region.
///
/// {0}: The name of the region.
const char *const regionParserCode = R"(
  if (parser.parseRegion(*{0}Region))
    return ::mlir::failure();
)";

/// The code snippet used to ensure a region has a terminator.
///
/// {0}: The name of the region.
const char *const regionEnsureTerminatorParserCode = R"(
  ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a region has a block, by emplacing one if
/// the region is empty.
///
/// {0}: The name of the region.
const char *const regionEnsureSingleBlockParserCode = R"(
  if ({0}Region->empty()) {0}Region->emplaceBlock();
)";
739 
// Declared `const char *const` for consistency with the other snippets: the
// pointers are immutable and have internal linkage.

/// The code snippet used to generate a parser call for a successor list. The
/// list may be empty; otherwise it is a comma-separated sequence of successors.
///
/// {0}: The name for the successor list.
const char *const successorListParserCode = R"(
  {
    ::mlir::Block *succ;
    auto firstSucc = parser.parseOptionalSuccessor(succ);
    if (firstSucc.has_value()) {
      if (failed(*firstSucc))
        return ::mlir::failure();
      {0}Successors.emplace_back(succ);

      // Parse any trailing successors.
      while (succeeded(parser.parseOptionalComma())) {
        if (parser.parseSuccessor(succ))
          return ::mlir::failure();
        {0}Successors.emplace_back(succ);
      }
    }
  }
)";

/// The code snippet used to generate a parser call for a successor.
///
/// {0}: The name of the successor.
const char *const successorParserCode = R"(
  if (parser.parseSuccessor({0}Successor))
    return ::mlir::failure();
)";
769 
/// The code snippet used to generate a parser for one clause of an OIList.
/// It rejects a clause that appears more than once by checking and then
/// setting the `{0}Clause` flag. It is declared `const char *const`, like the
/// other snippets, so the pointer is immutable and has internal linkage.
///
/// {0}: literal keyword corresponding to a case for oilist
const char *const oilistParserCode = R"(
  if ({0}Clause) {
    return parser.emitError(parser.getNameLoc())
          << "`{0}` clause can appear at most once in the expansion of the "
             "oilist directive";
  }
  {0}Clause = true;
)";
781 
782 namespace {
/// The type of length for a given parse argument, i.e. how many values the
/// argument may bind to. The enumerator order is left unchanged because it
/// determines the underlying values.
enum class ArgumentLengthKind {
  /// The argument is a variadic of a variadic, and may contain 0->N range
  /// elements.
  VariadicOfVariadic,
  /// The argument is variadic, and may contain 0->N elements.
  Variadic,
  /// The argument is optional, and may contain 0 or 1 elements.
  Optional,
  /// The argument is a single element, i.e. always represents 1 element.
  Single
};
795 } // namespace
796 
797 /// Get the length kind for the given constraint.
798 static ArgumentLengthKind
799 getArgumentLengthKind(const NamedTypeConstraint *var) {
800   if (var->isOptional())
801     return ArgumentLengthKind::Optional;
802   if (var->isVariadicOfVariadic())
803     return ArgumentLengthKind::VariadicOfVariadic;
804   if (var->isVariadic())
805     return ArgumentLengthKind::Variadic;
806   return ArgumentLengthKind::Single;
807 }
808 
809 /// Get the name used for the type list for the given type directive operand.
810 /// 'lengthKind' to the corresponding kind for the given argument.
811 static StringRef getTypeListName(FormatElement *arg,
812                                  ArgumentLengthKind &lengthKind) {
813   if (auto *operand = dyn_cast<OperandVariable>(arg)) {
814     lengthKind = getArgumentLengthKind(operand->getVar());
815     return operand->getVar()->name;
816   }
817   if (auto *result = dyn_cast<ResultVariable>(arg)) {
818     lengthKind = getArgumentLengthKind(result->getVar());
819     return result->getVar()->name;
820   }
821   lengthKind = ArgumentLengthKind::Variadic;
822   if (isa<OperandsDirective>(arg))
823     return "allOperand";
824   if (isa<ResultsDirective>(arg))
825     return "allResult";
826   llvm_unreachable("unknown 'type' directive argument");
827 }
828 
829 /// Generate the parser for a literal value.
830 static void genLiteralParser(StringRef value, MethodBody &body) {
831   // Handle the case of a keyword/identifier.
832   if (value.front() == '_' || isalpha(value.front())) {
833     body << "Keyword(\"" << value << "\")";
834     return;
835   }
836   body << (StringRef)StringSwitch<StringRef>(value)
837               .Case("->", "Arrow()")
838               .Case(":", "Colon()")
839               .Case(",", "Comma()")
840               .Case("=", "Equal()")
841               .Case("<", "Less()")
842               .Case(">", "Greater()")
843               .Case("{", "LBrace()")
844               .Case("}", "RBrace()")
845               .Case("(", "LParen()")
846               .Case(")", "RParen()")
847               .Case("[", "LSquare()")
848               .Case("]", "RSquare()")
849               .Case("?", "Question()")
850               .Case("+", "Plus()")
851               .Case("*", "Star()")
852               .Case("...", "Ellipsis()");
853 }
854 
855 /// Generate the storage code required for parsing the given element.
856 static void genElementParserStorage(FormatElement *element, const Operator &op,
857                                     MethodBody &body) {
858   if (auto *optional = dyn_cast<OptionalElement>(element)) {
859     ArrayRef<FormatElement *> elements = optional->getThenElements();
860 
861     // If the anchor is a unit attribute, it won't be parsed directly so elide
862     // it.
863     auto *anchor = dyn_cast<AttributeLikeVariable>(optional->getAnchor());
864     FormatElement *elidedAnchorElement = nullptr;
865     if (anchor && anchor != elements.front() && anchor->isUnit())
866       elidedAnchorElement = anchor;
867     for (FormatElement *childElement : elements)
868       if (childElement != elidedAnchorElement)
869         genElementParserStorage(childElement, op, body);
870     for (FormatElement *childElement : optional->getElseElements())
871       genElementParserStorage(childElement, op, body);
872 
873   } else if (auto *oilist = dyn_cast<OIListElement>(element)) {
874     for (ArrayRef<FormatElement *> pelement : oilist->getParsingElements()) {
875       if (!oilist->getUnitVariableParsingElement(pelement))
876         for (FormatElement *element : pelement)
877           genElementParserStorage(element, op, body);
878     }
879 
880   } else if (auto *custom = dyn_cast<CustomDirective>(element)) {
881     for (FormatElement *paramElement : custom->getArguments())
882       genElementParserStorage(paramElement, op, body);
883 
884   } else if (isa<OperandsDirective>(element)) {
885     body << "  ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
886             "allOperands;\n";
887 
888   } else if (isa<RegionsDirective>(element)) {
889     body << "  ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
890             "fullRegions;\n";
891 
892   } else if (isa<SuccessorsDirective>(element)) {
893     body << "  ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
894 
895   } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
896     const NamedAttribute *var = attr->getVar();
897     body << formatv("  {0} {1}Attr;\n", var->attr.getStorageType(), var->name);
898 
899   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
900     StringRef name = operand->getVar()->name;
901     if (operand->getVar()->isVariableLength()) {
902       body
903           << "  ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
904           << name << "Operands;\n";
905       if (operand->getVar()->isVariadicOfVariadic()) {
906         body << "    llvm::SmallVector<int32_t> " << name
907              << "OperandGroupSizes;\n";
908       }
909     } else {
910       body << "  ::mlir::OpAsmParser::UnresolvedOperand " << name
911            << "RawOperand{};\n"
912            << "  ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> "
913            << name << "Operands(&" << name << "RawOperand, 1);";
914     }
915     body << formatv("  ::llvm::SMLoc {0}OperandsLoc;\n"
916                     "  (void){0}OperandsLoc;\n",
917                     name);
918 
919   } else if (auto *region = dyn_cast<RegionVariable>(element)) {
920     StringRef name = region->getVar()->name;
921     if (region->getVar()->isVariadic()) {
922       body << formatv(
923           "  ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
924           "{0}Regions;\n",
925           name);
926     } else {
927       body << formatv("  std::unique_ptr<::mlir::Region> {0}Region = "
928                       "std::make_unique<::mlir::Region>();\n",
929                       name);
930     }
931 
932   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
933     StringRef name = successor->getVar()->name;
934     if (successor->getVar()->isVariadic()) {
935       body << formatv("  ::llvm::SmallVector<::mlir::Block *, 2> "
936                       "{0}Successors;\n",
937                       name);
938     } else {
939       body << formatv("  ::mlir::Block *{0}Successor = nullptr;\n", name);
940     }
941 
942   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
943     ArgumentLengthKind lengthKind;
944     StringRef name = getTypeListName(dir->getArg(), lengthKind);
945     if (lengthKind != ArgumentLengthKind::Single)
946       body << "  ::llvm::SmallVector<::mlir::Type, 1> " << name << "Types;\n";
947     else
948       body
949           << formatv("  ::mlir::Type {0}RawType{{};\n", name)
950           << formatv(
951                  "  ::llvm::ArrayRef<::mlir::Type> {0}Types(&{0}RawType, 1);\n",
952                  name);
953   } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
954     ArgumentLengthKind ignored;
955     body << "  ::llvm::ArrayRef<::mlir::Type> "
956          << getTypeListName(dir->getInputs(), ignored) << "Types;\n";
957     body << "  ::llvm::ArrayRef<::mlir::Type> "
958          << getTypeListName(dir->getResults(), ignored) << "Types;\n";
959   }
960 }
961 
962 /// Generate the parser for a parameter to a custom directive.
963 static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
964   if (auto *attr = dyn_cast<AttributeVariable>(param)) {
965     body << attr->getVar()->name << "Attr";
966   } else if (isa<AttrDictDirective>(param)) {
967     body << "result.attributes";
968   } else if (isa<PropDictDirective>(param)) {
969     body << "result";
970   } else if (auto *operand = dyn_cast<OperandVariable>(param)) {
971     StringRef name = operand->getVar()->name;
972     ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
973     if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
974       body << formatv("{0}OperandGroups", name);
975     else if (lengthKind == ArgumentLengthKind::Variadic)
976       body << formatv("{0}Operands", name);
977     else if (lengthKind == ArgumentLengthKind::Optional)
978       body << formatv("{0}Operand", name);
979     else
980       body << formatv("{0}RawOperand", name);
981 
982   } else if (auto *region = dyn_cast<RegionVariable>(param)) {
983     StringRef name = region->getVar()->name;
984     if (region->getVar()->isVariadic())
985       body << formatv("{0}Regions", name);
986     else
987       body << formatv("*{0}Region", name);
988 
989   } else if (auto *successor = dyn_cast<SuccessorVariable>(param)) {
990     StringRef name = successor->getVar()->name;
991     if (successor->getVar()->isVariadic())
992       body << formatv("{0}Successors", name);
993     else
994       body << formatv("{0}Successor", name);
995 
996   } else if (auto *dir = dyn_cast<RefDirective>(param)) {
997     genCustomParameterParser(dir->getArg(), body);
998 
999   } else if (auto *dir = dyn_cast<TypeDirective>(param)) {
1000     ArgumentLengthKind lengthKind;
1001     StringRef listName = getTypeListName(dir->getArg(), lengthKind);
1002     if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
1003       body << formatv("{0}TypeGroups", listName);
1004     else if (lengthKind == ArgumentLengthKind::Variadic)
1005       body << formatv("{0}Types", listName);
1006     else if (lengthKind == ArgumentLengthKind::Optional)
1007       body << formatv("{0}Type", listName);
1008     else
1009       body << formatv("{0}RawType", listName);
1010 
1011   } else if (auto *string = dyn_cast<StringElement>(param)) {
1012     FmtContext ctx;
1013     ctx.withBuilder("parser.getBuilder()");
1014     ctx.addSubst("_ctxt", "parser.getContext()");
1015     body << tgfmt(string->getValue(), &ctx);
1016 
1017   } else if (auto *property = dyn_cast<PropertyVariable>(param)) {
1018     body << formatv("result.getOrAddProperties<Properties>().{0}",
1019                     property->getVar()->name);
1020   } else {
1021     llvm_unreachable("unknown custom directive parameter");
1022   }
1023 }
1024 
1025 /// Generate the parser for a custom directive.
1026 static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
1027                                      bool useProperties,
1028                                      StringRef opCppClassName,
1029                                      bool isOptional = false) {
1030   body << "  {\n";
1031 
1032   // Preprocess the directive variables.
1033   // * Add a local variable for optional operands and types. This provides a
1034   //   better API to the user defined parser methods.
1035   // * Set the location of operand variables.
1036   for (FormatElement *param : dir->getArguments()) {
1037     if (auto *operand = dyn_cast<OperandVariable>(param)) {
1038       auto *var = operand->getVar();
1039       body << "    " << var->name
1040            << "OperandsLoc = parser.getCurrentLocation();\n";
1041       if (var->isOptional()) {
1042         body << formatv(
1043             "    ::std::optional<::mlir::OpAsmParser::UnresolvedOperand> "
1044             "{0}Operand;\n",
1045             var->name);
1046       } else if (var->isVariadicOfVariadic()) {
1047         body << formatv("    "
1048                         "::llvm::SmallVector<::llvm::SmallVector<::mlir::"
1049                         "OpAsmParser::UnresolvedOperand>> "
1050                         "{0}OperandGroups;\n",
1051                         var->name);
1052       }
1053     } else if (auto *dir = dyn_cast<TypeDirective>(param)) {
1054       ArgumentLengthKind lengthKind;
1055       StringRef listName = getTypeListName(dir->getArg(), lengthKind);
1056       if (lengthKind == ArgumentLengthKind::Optional) {
1057         body << formatv("    ::mlir::Type {0}Type;\n", listName);
1058       } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
1059         body << formatv(
1060             "    ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> "
1061             "{0}TypeGroups;\n",
1062             listName);
1063       }
1064     } else if (auto *dir = dyn_cast<RefDirective>(param)) {
1065       FormatElement *input = dir->getArg();
1066       if (auto *operand = dyn_cast<OperandVariable>(input)) {
1067         if (!operand->getVar()->isOptional())
1068           continue;
1069         body << formatv(
1070             "    {0} {1}Operand = {1}Operands.empty() ? {0}() : "
1071             "{1}Operands[0];\n",
1072             "::std::optional<::mlir::OpAsmParser::UnresolvedOperand>",
1073             operand->getVar()->name);
1074 
1075       } else if (auto *type = dyn_cast<TypeDirective>(input)) {
1076         ArgumentLengthKind lengthKind;
1077         StringRef listName = getTypeListName(type->getArg(), lengthKind);
1078         if (lengthKind == ArgumentLengthKind::Optional) {
1079           body << formatv("    ::mlir::Type {0}Type = {0}Types.empty() ? "
1080                           "::mlir::Type() : {0}Types[0];\n",
1081                           listName);
1082         }
1083       }
1084     }
1085   }
1086 
1087   body << "    auto odsResult = parse" << dir->getName() << "(parser";
1088   for (FormatElement *param : dir->getArguments()) {
1089     body << ", ";
1090     genCustomParameterParser(param, body);
1091   }
1092   body << ");\n";
1093 
1094   if (isOptional) {
1095     body << "    if (!odsResult.has_value()) return {};\n"
1096          << "    if (::mlir::failed(*odsResult)) return ::mlir::failure();\n";
1097   } else {
1098     body << "    if (odsResult) return ::mlir::failure();\n";
1099   }
1100 
1101   // After parsing, add handling for any of the optional constructs.
1102   for (FormatElement *param : dir->getArguments()) {
1103     if (auto *attr = dyn_cast<AttributeVariable>(param)) {
1104       const NamedAttribute *var = attr->getVar();
1105       if (var->attr.isOptional() || var->attr.hasDefaultValue())
1106         body << formatv("    if ({0}Attr)\n  ", var->name);
1107       if (useProperties) {
1108         body << formatv(
1109             "    result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n",
1110             var->name, opCppClassName);
1111       } else {
1112         body << formatv("    result.addAttribute(\"{0}\", {0}Attr);\n",
1113                         var->name);
1114       }
1115     } else if (auto *operand = dyn_cast<OperandVariable>(param)) {
1116       const NamedTypeConstraint *var = operand->getVar();
1117       if (var->isOptional()) {
1118         body << formatv("    if ({0}Operand.has_value())\n"
1119                         "      {0}Operands.push_back(*{0}Operand);\n",
1120                         var->name);
1121       } else if (var->isVariadicOfVariadic()) {
1122         body << formatv(
1123             "    for (const auto &subRange : {0}OperandGroups) {{\n"
1124             "      {0}Operands.append(subRange.begin(), subRange.end());\n"
1125             "      {0}OperandGroupSizes.push_back(subRange.size());\n"
1126             "    }\n",
1127             var->name);
1128       }
1129     } else if (auto *dir = dyn_cast<TypeDirective>(param)) {
1130       ArgumentLengthKind lengthKind;
1131       StringRef listName = getTypeListName(dir->getArg(), lengthKind);
1132       if (lengthKind == ArgumentLengthKind::Optional) {
1133         body << formatv("    if ({0}Type)\n"
1134                         "      {0}Types.push_back({0}Type);\n",
1135                         listName);
1136       } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
1137         body << formatv(
1138             "    for (const auto &subRange : {0}TypeGroups)\n"
1139             "      {0}Types.append(subRange.begin(), subRange.end());\n",
1140             listName);
1141       }
1142     }
1143   }
1144 
1145   body << "  }\n";
1146 }
1147 
1148 /// Generate the parser for a enum attribute.
1149 static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
1150                               FmtContext &attrTypeCtx, bool parseAsOptional,
1151                               bool useProperties, StringRef opCppClassName) {
1152   Attribute baseAttr = var->attr.getBaseAttr();
1153   const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
1154   std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
1155 
1156   // Generate the code for building an attribute for this enum.
1157   std::string attrBuilderStr;
1158   {
1159     llvm::raw_string_ostream os(attrBuilderStr);
1160     os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx,
1161                 "*attrOptional");
1162   }
1163 
1164   // Build a string containing the cases that can be formatted as a keyword.
1165   std::string validCaseKeywordsStr = "{";
1166   llvm::raw_string_ostream validCaseKeywordsOS(validCaseKeywordsStr);
1167   for (const EnumAttrCase &attrCase : cases)
1168     if (canFormatStringAsKeyword(attrCase.getStr()))
1169       validCaseKeywordsOS << '"' << attrCase.getStr() << "\",";
1170   validCaseKeywordsOS.str().back() = '}';
1171 
1172   // If the attribute is not optional, build an error message for the missing
1173   // attribute.
1174   std::string errorMessage;
1175   if (!parseAsOptional) {
1176     llvm::raw_string_ostream errorMessageOS(errorMessage);
1177     errorMessageOS
1178         << "return parser.emitError(loc, \"expected string or "
1179            "keyword containing one of the following enum values for attribute '"
1180         << var->name << "' [";
1181     llvm::interleaveComma(cases, errorMessageOS, [&](const auto &attrCase) {
1182       errorMessageOS << attrCase.getStr();
1183     });
1184     errorMessageOS << "]\");";
1185   }
1186   std::string attrAssignment;
1187   if (useProperties) {
1188     attrAssignment =
1189         formatv("  "
1190                 "result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;",
1191                 var->name, opCppClassName);
1192   } else {
1193     attrAssignment =
1194         formatv("result.addAttribute(\"{0}\", {0}Attr);", var->name);
1195   }
1196 
1197   body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(),
1198                   enumAttr.getStringToSymbolFnName(), attrBuilderStr,
1199                   validCaseKeywordsStr, errorMessage, attrAssignment);
1200 }
1201 
1202 // Generate the parser for a property.
1203 static void genPropertyParser(PropertyVariable *propVar, MethodBody &body,
1204                               StringRef opCppClassName,
1205                               bool requireParse = true) {
1206   StringRef name = propVar->getVar()->name;
1207   const Property &prop = propVar->getVar()->prop;
1208   bool parseOptionally =
1209       prop.hasDefaultValue() && !requireParse && prop.hasOptionalParser();
1210   FmtContext fmtContext;
1211   fmtContext.addSubst("_parser", "parser");
1212   fmtContext.addSubst("_ctxt", "parser.getContext()");
1213   fmtContext.addSubst("_storage", "propStorage");
1214 
1215   if (parseOptionally) {
1216     body << formatv(optionalPropertyParserCode, name, opCppClassName,
1217                     tgfmt(prop.getOptionalParserCall(), &fmtContext));
1218   } else {
1219     body << formatv(propertyParserCode, name, opCppClassName,
1220                     tgfmt(prop.getParserCall(), &fmtContext),
1221                     prop.getSummary());
1222   }
1223 }
1224 
1225 // Generate the parser for an attribute.
1226 static void genAttrParser(AttributeVariable *attr, MethodBody &body,
1227                           FmtContext &attrTypeCtx, bool parseAsOptional,
1228                           bool useProperties, StringRef opCppClassName) {
1229   const NamedAttribute *var = attr->getVar();
1230 
1231   // Check to see if we can parse this as an enum attribute.
1232   if (canFormatEnumAttr(var))
1233     return genEnumAttrParser(var, body, attrTypeCtx, parseAsOptional,
1234                              useProperties, opCppClassName);
1235 
1236   // Check to see if we should parse this as a symbol name attribute.
1237   if (shouldFormatSymbolNameAttr(var)) {
1238     body << formatv(parseAsOptional ? optionalSymbolNameAttrParserCode
1239                                     : symbolNameAttrParserCode,
1240                     var->name);
1241   } else {
1242 
1243     // If this attribute has a buildable type, use that when parsing the
1244     // attribute.
1245     std::string attrTypeStr;
1246     if (std::optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
1247       llvm::raw_string_ostream os(attrTypeStr);
1248       os << tgfmt(*typeBuilder, &attrTypeCtx);
1249     } else {
1250       attrTypeStr = "::mlir::Type{}";
1251     }
1252     if (parseAsOptional) {
1253       body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
1254     } else {
1255       if (attr->shouldBeQualified() ||
1256           var->attr.getStorageType() == "::mlir::Attribute")
1257         body << formatv(genericAttrParserCode, var->name, attrTypeStr);
1258       else
1259         body << formatv(attrParserCode, var->name, attrTypeStr);
1260     }
1261   }
1262   if (useProperties) {
1263     body << formatv(
1264         "  if ({0}Attr) result.getOrAddProperties<{1}::Properties>().{0} = "
1265         "{0}Attr;\n",
1266         var->name, opCppClassName);
1267   } else {
1268     body << formatv(
1269         "  if ({0}Attr) result.attributes.append(\"{0}\", {0}Attr);\n",
1270         var->name);
1271   }
1272 }
1273 
// Generates the 'setPropertiesFromParsedAttr' used to set properties from a
// 'prop-dict' dictionary attr.
//
// The generated static method requires `attr` to be a DictionaryAttr. For
// every property, and every non-derived attribute, that the format does not
// already parse elsewhere, it looks up the entry by name and stores it into
// `prop`. A missing entry is an error only for properties without a default
// value, and for attributes that are neither optional nor have a default
// value.
static void genParsedAttrPropertiesSetter(OperationFormat &fmt, Operator &op,
                                          OpClass &opClass) {
  // Only required when 'prop-dict' is present and properties are in use.
  if (!fmt.hasPropDict || !fmt.useProperties)
    return;

  SmallVector<MethodParameter> paramList;
  paramList.emplace_back("Properties &", "prop");
  paramList.emplace_back("::mlir::Attribute", "attr");
  paramList.emplace_back("::llvm::function_ref<::mlir::InFlightDiagnostic()>",
                         "emitError");

  Method *method = opClass.addStaticMethod("::llvm::LogicalResult",
                                           "setPropertiesFromParsedAttr",
                                           std::move(paramList));
  MethodBody &body = method->body().indent();

  body << R"decl(
::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr);
if (!dict) {
  emitError() << "expected DictionaryAttr to set properties";
  return ::mlir::failure();
}
)decl";

  // {0}: fromAttribute call
  // {1}: property name
  // {2}: isRequired
  const char *propFromAttrFmt = R"decl(
auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr,
         ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) -> ::mlir::LogicalResult {{
  {0};
};
auto attr = dict.get("{1}");
if (!attr && {2}) {{
  emitError() << "expected key entry for {1} in DictionaryAttr to set "
             "Properties.";
  return ::mlir::failure();
}
if (attr && ::mlir::failed(setFromAttr(prop.{1}, attr, emitError)))
  return ::mlir::failure();
)decl";

  // Generate the setter for any property not parsed elsewhere.
  for (const NamedProperty &namedProperty : op.getProperties()) {
    if (fmt.usedProperties.contains(&namedProperty))
      continue;

    // Each entry is emitted in its own scope so the generated locals
    // (`setFromAttr`, `attr`) do not clash between entries.
    auto scope = body.scope("{\n", "}\n", /*indent=*/true);

    StringRef name = namedProperty.name;
    const Property &prop = namedProperty.prop;
    // Properties without a default value must be present in the dictionary.
    bool isRequired = !prop.hasDefaultValue();
    FmtContext fctx;
    body << formatv(propFromAttrFmt,
                    tgfmt(prop.getConvertFromAttributeCall(),
                          &fctx.addSubst("_attr", "propAttr")
                               .addSubst("_storage", "propStorage")
                               .addSubst("_diag", "emitError")),
                    name, isRequired);
  }

  // Generate the setter for any attribute not parsed elsewhere.
  for (const NamedAttribute &namedAttr : op.getAttributes()) {
    if (fmt.usedAttributes.contains(&namedAttr))
      continue;

    const Attribute &attr = namedAttr.attr;
    // Derived attributes do not need to be parsed.
    if (attr.isDerivedAttr())
      continue;

    auto scope = body.scope("{\n", "}\n", /*indent=*/true);

    // If the attribute has a default value or is optional, it does not need to
    // be present in the parsed dictionary attribute.
    bool isRequired = !attr.isOptional() && !attr.hasDefaultValue();
    body << formatv(R"decl(
auto &propStorage = prop.{0};
auto attr = dict.get("{0}");
if (attr || /*isRequired=*/{1}) {{
  if (!attr) {{
    emitError() << "expected key entry for {0} in DictionaryAttr to set "
               "Properties.";
    return ::mlir::failure();
  }
  auto convertedAttr = ::llvm::dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr);
  if (convertedAttr) {{
    propStorage = convertedAttr;
  } else {{
    emitError() << "Invalid attribute `{0}` in property conversion: " << attr;
    return ::mlir::failure();
  }
}
)decl",
                    namedAttr.name, isRequired);
  }
  body << "return ::mlir::success();\n";
}
1375 
1376 void OperationFormat::genParser(Operator &op, OpClass &opClass) {
1377   SmallVector<MethodParameter> paramList;
1378   paramList.emplace_back("::mlir::OpAsmParser &", "parser");
1379   paramList.emplace_back("::mlir::OperationState &", "result");
1380 
1381   auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse",
1382                                          std::move(paramList));
1383   auto &body = method->body();
1384 
1385   // Generate variables to store the operands and type within the format. This
1386   // allows for referencing these variables in the presence of optional
1387   // groupings.
1388   for (FormatElement *element : elements)
1389     genElementParserStorage(element, op, body);
1390 
1391   // A format context used when parsing attributes with buildable types.
1392   FmtContext attrTypeCtx;
1393   attrTypeCtx.withBuilder("parser.getBuilder()");
1394 
1395   // Generate parsers for each of the elements.
1396   for (FormatElement *element : elements)
1397     genElementParser(element, body, attrTypeCtx);
1398 
1399   // Generate the code to resolve the operand/result types and successors now
1400   // that they have been parsed.
1401   genParserRegionResolution(op, body);
1402   genParserSuccessorResolution(op, body);
1403   genParserVariadicSegmentResolution(op, body);
1404   genParserTypeResolution(op, body);
1405 
1406   body << "  return ::mlir::success();\n";
1407 
1408   genParsedAttrPropertiesSetter(*this, op, opClass);
1409 }
1410 
1411 void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
1412                                        FmtContext &attrTypeCtx,
1413                                        GenContext genCtx) {
1414   /// Optional Group.
1415   if (auto *optional = dyn_cast<OptionalElement>(element)) {
1416     auto genElementParsers = [&](FormatElement *firstElement,
1417                                  ArrayRef<FormatElement *> elements,
1418                                  bool thenGroup) {
1419       // If the anchor is a unit attribute, we don't need to print it. When
1420       // parsing, we will add this attribute if this group is present.
1421       FormatElement *elidedAnchorElement = nullptr;
1422       auto *anchorVar = dyn_cast<AttributeLikeVariable>(optional->getAnchor());
1423       if (anchorVar && anchorVar != firstElement && anchorVar->isUnit()) {
1424         elidedAnchorElement = anchorVar;
1425 
1426         if (!thenGroup == optional->isInverted()) {
1427           // Add the anchor unit attribute or property to the operation state
1428           // or set the property to true.
1429           if (isa<PropertyVariable>(anchorVar)) {
1430             body << formatv(
1431                 "    result.getOrAddProperties<{1}::Properties>().{0} = true;",
1432                 anchorVar->getName(), opCppClassName);
1433           } else if (useProperties) {
1434             body << formatv(
1435                 "    result.getOrAddProperties<{1}::Properties>().{0} = "
1436                 "parser.getBuilder().getUnitAttr();",
1437                 anchorVar->getName(), opCppClassName);
1438           } else {
1439             body << "    result.addAttribute(\"" << anchorVar->getName()
1440                  << "\", parser.getBuilder().getUnitAttr());\n";
1441           }
1442         }
1443       }
1444 
1445       // Generate the rest of the elements inside an optional group. Elements in
1446       // an optional group after the guard are parsed as required.
1447       for (FormatElement *childElement : elements)
1448         if (childElement != elidedAnchorElement)
1449           genElementParser(childElement, body, attrTypeCtx,
1450                            GenContext::Optional);
1451     };
1452 
1453     ArrayRef<FormatElement *> thenElements =
1454         optional->getThenElements(/*parseable=*/true);
1455 
1456     // Generate a special optional parser for the first element to gate the
1457     // parsing of the rest of the elements.
1458     FormatElement *firstElement = thenElements.front();
1459     if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
1460       genAttrParser(attrVar, body, attrTypeCtx, /*parseAsOptional=*/true,
1461                     useProperties, opCppClassName);
1462       body << "  if (" << attrVar->getVar()->name << "Attr) {\n";
1463     } else if (auto *propVar = dyn_cast<PropertyVariable>(firstElement)) {
1464       genPropertyParser(propVar, body, opCppClassName, /*requireParse=*/false);
1465       body << formatv("if ({0}PropParseResult.has_value() && "
1466                       "succeeded(*{0}PropParseResult)) ",
1467                       propVar->getVar()->name)
1468            << " {\n";
1469     } else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
1470       body << "  if (::mlir::succeeded(parser.parseOptional";
1471       genLiteralParser(literal->getSpelling(), body);
1472       body << ")) {\n";
1473     } else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
1474       genElementParser(opVar, body, attrTypeCtx);
1475       body << "  if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
1476     } else if (auto *regionVar = dyn_cast<RegionVariable>(firstElement)) {
1477       const NamedRegion *region = regionVar->getVar();
1478       if (region->isVariadic()) {
1479         genElementParser(regionVar, body, attrTypeCtx);
1480         body << "  if (!" << region->name << "Regions.empty()) {\n";
1481       } else {
1482         body << formatv(optionalRegionParserCode, region->name);
1483         body << "  if (!" << region->name << "Region->empty()) {\n  ";
1484         if (hasImplicitTermTrait)
1485           body << formatv(regionEnsureTerminatorParserCode, region->name);
1486         else if (hasSingleBlockTrait)
1487           body << formatv(regionEnsureSingleBlockParserCode, region->name);
1488       }
1489     } else if (auto *custom = dyn_cast<CustomDirective>(firstElement)) {
1490       body << "  if (auto optResult = [&]() -> ::mlir::OptionalParseResult {\n";
1491       genCustomDirectiveParser(custom, body, useProperties, opCppClassName,
1492                                /*isOptional=*/true);
1493       body << "    return ::mlir::success();\n"
1494            << "  }(); optResult.has_value() && ::mlir::failed(*optResult)) {\n"
1495            << "    return ::mlir::failure();\n"
1496            << "  } else if (optResult.has_value()) {\n";
1497     }
1498 
1499     genElementParsers(firstElement, thenElements.drop_front(),
1500                       /*thenGroup=*/true);
1501     body << "  }";
1502 
1503     // Generate the else elements.
1504     auto elseElements = optional->getElseElements();
1505     if (!elseElements.empty()) {
1506       body << " else {\n";
1507       ArrayRef<FormatElement *> elseElements =
1508           optional->getElseElements(/*parseable=*/true);
1509       genElementParsers(elseElements.front(), elseElements,
1510                         /*thenGroup=*/false);
1511       body << "  }";
1512     }
1513     body << "\n";
1514 
1515     /// OIList Directive
1516   } else if (OIListElement *oilist = dyn_cast<OIListElement>(element)) {
1517     for (LiteralElement *le : oilist->getLiteralElements())
1518       body << "  bool " << le->getSpelling() << "Clause = false;\n";
1519 
1520     // Generate the parsing loop
1521     body << "  while(true) {\n";
1522     for (auto clause : oilist->getClauses()) {
1523       LiteralElement *lelement = std::get<0>(clause);
1524       ArrayRef<FormatElement *> pelement = std::get<1>(clause);
1525       body << "if (succeeded(parser.parseOptional";
1526       genLiteralParser(lelement->getSpelling(), body);
1527       body << ")) {\n";
1528       StringRef lelementName = lelement->getSpelling();
1529       body << formatv(oilistParserCode, lelementName);
1530       if (AttributeLikeVariable *unitVarElem =
1531               oilist->getUnitVariableParsingElement(pelement)) {
1532         if (isa<PropertyVariable>(unitVarElem)) {
1533           body << formatv(
1534               "    result.getOrAddProperties<{1}::Properties>().{0} = true;",
1535               unitVarElem->getName(), opCppClassName);
1536         } else if (useProperties) {
1537           body << formatv(
1538               "    result.getOrAddProperties<{1}::Properties>().{0} = "
1539               "parser.getBuilder().getUnitAttr();",
1540               unitVarElem->getName(), opCppClassName);
1541         } else {
1542           body << "  result.addAttribute(\"" << unitVarElem->getName()
1543                << "\", UnitAttr::get(parser.getContext()));\n";
1544         }
1545       } else {
1546         for (FormatElement *el : pelement)
1547           genElementParser(el, body, attrTypeCtx);
1548       }
1549       body << "    } else ";
1550     }
1551     body << " {\n";
1552     body << "    break;\n";
1553     body << "  }\n";
1554     body << "}\n";
1555 
1556     /// Literals.
1557   } else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) {
1558     body << "  if (parser.parse";
1559     genLiteralParser(literal->getSpelling(), body);
1560     body << ")\n    return ::mlir::failure();\n";
1561 
1562     /// Whitespaces.
1563   } else if (isa<WhitespaceElement>(element)) {
1564     // Nothing to parse.
1565 
1566     /// Arguments.
1567   } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
1568     bool parseAsOptional =
1569         (genCtx == GenContext::Normal && attr->getVar()->attr.isOptional());
1570     genAttrParser(attr, body, attrTypeCtx, parseAsOptional, useProperties,
1571                   opCppClassName);
1572   } else if (auto *prop = dyn_cast<PropertyVariable>(element)) {
1573     genPropertyParser(prop, body, opCppClassName);
1574 
1575   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
1576     ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
1577     StringRef name = operand->getVar()->name;
1578     if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
1579       body << formatv(variadicOfVariadicOperandParserCode, name);
1580     else if (lengthKind == ArgumentLengthKind::Variadic)
1581       body << formatv(variadicOperandParserCode, name);
1582     else if (lengthKind == ArgumentLengthKind::Optional)
1583       body << formatv(optionalOperandParserCode, name);
1584     else
1585       body << formatv(operandParserCode, name);
1586 
1587   } else if (auto *region = dyn_cast<RegionVariable>(element)) {
1588     bool isVariadic = region->getVar()->isVariadic();
1589     body << formatv(isVariadic ? regionListParserCode : regionParserCode,
1590                     region->getVar()->name);
1591     if (hasImplicitTermTrait)
1592       body << formatv(isVariadic ? regionListEnsureTerminatorParserCode
1593                                  : regionEnsureTerminatorParserCode,
1594                       region->getVar()->name);
1595     else if (hasSingleBlockTrait)
1596       body << formatv(isVariadic ? regionListEnsureSingleBlockParserCode
1597                                  : regionEnsureSingleBlockParserCode,
1598                       region->getVar()->name);
1599 
1600   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
1601     bool isVariadic = successor->getVar()->isVariadic();
1602     body << formatv(isVariadic ? successorListParserCode : successorParserCode,
1603                     successor->getVar()->name);
1604 
1605     /// Directives.
1606   } else if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
1607     body.indent() << "{\n";
1608     body.indent() << "auto loc = parser.getCurrentLocation();(void)loc;\n"
1609                   << "if (parser.parseOptionalAttrDict"
1610                   << (attrDict->isWithKeyword() ? "WithKeyword" : "")
1611                   << "(result.attributes))\n"
1612                   << "  return ::mlir::failure();\n";
1613     if (useProperties) {
1614       body << "if (failed(verifyInherentAttrs(result.name, result.attributes, "
1615               "[&]() {\n"
1616            << "    return parser.emitError(loc) << \"'\" << "
1617               "result.name.getStringRef() << \"' op \";\n"
1618            << "  })))\n"
1619            << "  return ::mlir::failure();\n";
1620     }
1621     body.unindent() << "}\n";
1622     body.unindent();
1623   } else if (isa<PropDictDirective>(element)) {
1624     if (useProperties) {
1625       body << "  if (parseProperties(parser, result))\n"
1626            << "    return ::mlir::failure();\n";
1627     }
1628   } else if (auto *customDir = dyn_cast<CustomDirective>(element)) {
1629     genCustomDirectiveParser(customDir, body, useProperties, opCppClassName);
1630   } else if (isa<OperandsDirective>(element)) {
1631     body << "  [[maybe_unused]] ::llvm::SMLoc allOperandLoc ="
1632          << " parser.getCurrentLocation();\n"
1633          << "  if (parser.parseOperandList(allOperands))\n"
1634          << "    return ::mlir::failure();\n";
1635 
1636   } else if (isa<RegionsDirective>(element)) {
1637     body << formatv(regionListParserCode, "full");
1638     if (hasImplicitTermTrait)
1639       body << formatv(regionListEnsureTerminatorParserCode, "full");
1640     else if (hasSingleBlockTrait)
1641       body << formatv(regionListEnsureSingleBlockParserCode, "full");
1642 
1643   } else if (isa<SuccessorsDirective>(element)) {
1644     body << formatv(successorListParserCode, "full");
1645 
1646   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
1647     ArgumentLengthKind lengthKind;
1648     StringRef listName = getTypeListName(dir->getArg(), lengthKind);
1649     if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
1650       body << formatv(variadicOfVariadicTypeParserCode, listName);
1651     } else if (lengthKind == ArgumentLengthKind::Variadic) {
1652       body << formatv(variadicTypeParserCode, listName);
1653     } else if (lengthKind == ArgumentLengthKind::Optional) {
1654       body << formatv(optionalTypeParserCode, listName);
1655     } else {
1656       const char *parserCode =
1657           dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode;
1658       TypeSwitch<FormatElement *>(dir->getArg())
1659           .Case<OperandVariable, ResultVariable>([&](auto operand) {
1660             body << formatv(false, parserCode,
1661                             operand->getVar()->constraint.getCppType(),
1662                             listName);
1663           })
1664           .Default([&](auto operand) {
1665             body << formatv(false, parserCode, "::mlir::Type", listName);
1666           });
1667     }
1668   } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
1669     ArgumentLengthKind ignored;
1670     body << formatv(functionalTypeParserCode,
1671                     getTypeListName(dir->getInputs(), ignored),
1672                     getTypeListName(dir->getResults(), ignored));
1673   } else {
1674     llvm_unreachable("unknown format element");
1675   }
1676 }
1677 
1678 void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
1679   // If any of type resolutions use transformed variables, make sure that the
1680   // types of those variables are resolved.
1681   SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables;
1682   FmtContext verifierFCtx;
1683   for (TypeResolution &resolver :
1684        llvm::concat<TypeResolution>(resultTypes, operandTypes)) {
1685     std::optional<StringRef> transformer = resolver.getVarTransformer();
1686     if (!transformer)
1687       continue;
1688     // Ensure that we don't verify the same variables twice.
1689     const NamedTypeConstraint *variable = resolver.getVariable();
1690     if (!variable || !verifiedVariables.insert(variable).second)
1691       continue;
1692 
1693     auto constraint = variable->constraint;
1694     body << "  for (::mlir::Type type : " << variable->name << "Types) {\n"
1695          << "    (void)type;\n"
1696          << "    if (!("
1697          << tgfmt(constraint.getConditionTemplate(),
1698                   &verifierFCtx.withSelf("type"))
1699          << ")) {\n"
1700          << formatv("      return parser.emitError(parser.getNameLoc()) << "
1701                     "\"'{0}' must be {1}, but got \" << type;\n",
1702                     variable->name, constraint.getSummary())
1703          << "    }\n"
1704          << "  }\n";
1705   }
1706 
1707   // Initialize the set of buildable types.
1708   if (!buildableTypes.empty()) {
1709     FmtContext typeBuilderCtx;
1710     typeBuilderCtx.withBuilder("parser.getBuilder()");
1711     for (auto &it : buildableTypes)
1712       body << "  ::mlir::Type odsBuildableType" << it.second << " = "
1713            << tgfmt(it.first, &typeBuilderCtx) << ";\n";
1714   }
1715 
1716   // Emit the code necessary for a type resolver.
1717   auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) {
1718     if (std::optional<int> val = resolver.getBuilderIdx()) {
1719       body << "odsBuildableType" << *val;
1720     } else if (const NamedTypeConstraint *var = resolver.getVariable()) {
1721       if (std::optional<StringRef> tform = resolver.getVarTransformer()) {
1722         FmtContext fmtContext;
1723         fmtContext.addSubst("_ctxt", "parser.getContext()");
1724         if (var->isVariadic())
1725           fmtContext.withSelf(var->name + "Types");
1726         else
1727           fmtContext.withSelf(var->name + "Types[0]");
1728         body << tgfmt(*tform, &fmtContext);
1729       } else {
1730         body << var->name << "Types";
1731         if (!var->isVariadic())
1732           body << "[0]";
1733       }
1734     } else if (const NamedAttribute *attr = resolver.getAttribute()) {
1735       if (std::optional<StringRef> tform = resolver.getVarTransformer())
1736         body << tgfmt(*tform,
1737                       &FmtContext().withSelf(attr->name + "Attr.getType()"));
1738       else
1739         body << attr->name << "Attr.getType()";
1740     } else {
1741       body << curVar << "Types";
1742     }
1743   };
1744 
1745   // Resolve each of the result types.
1746   if (!infersResultTypes) {
1747     if (allResultTypes) {
1748       body << "  result.addTypes(allResultTypes);\n";
1749     } else {
1750       for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
1751         body << "  result.addTypes(";
1752         emitTypeResolver(resultTypes[i], op.getResultName(i));
1753         body << ");\n";
1754       }
1755     }
1756   }
1757 
1758   // Emit the operand type resolutions.
1759   genParserOperandTypeResolution(op, body, emitTypeResolver);
1760 
1761   // Handle return type inference once all operands have been resolved
1762   if (infersResultTypes)
1763     body << formatv(inferReturnTypesParserCode, op.getCppClassName());
1764 }
1765 
1766 void OperationFormat::genParserOperandTypeResolution(
1767     Operator &op, MethodBody &body,
1768     function_ref<void(TypeResolution &, StringRef)> emitTypeResolver) {
1769   // Early exit if there are no operands.
1770   if (op.getNumOperands() == 0)
1771     return;
1772 
1773   // Handle the case where all operand types are grouped together with
1774   // "types(operands)".
1775   if (allOperandTypes) {
1776     // If `operands` was specified, use the full operand list directly.
1777     if (allOperands) {
1778       body << "  if (parser.resolveOperands(allOperands, allOperandTypes, "
1779               "allOperandLoc, result.operands))\n"
1780               "    return ::mlir::failure();\n";
1781       return;
1782     }
1783 
1784     // Otherwise, use llvm::concat to merge the disjoint operand lists together.
1785     // llvm::concat does not allow the case of a single range, so guard it here.
1786     body << "  if (parser.resolveOperands(";
1787     if (op.getNumOperands() > 1) {
1788       body << "::llvm::concat<const ::mlir::OpAsmParser::UnresolvedOperand>(";
1789       llvm::interleaveComma(op.getOperands(), body, [&](auto &operand) {
1790         body << operand.name << "Operands";
1791       });
1792       body << ")";
1793     } else {
1794       body << op.operand_begin()->name << "Operands";
1795     }
1796     body << ", allOperandTypes, parser.getNameLoc(), result.operands))\n"
1797          << "    return ::mlir::failure();\n";
1798     return;
1799   }
1800 
1801   // Handle the case where all operands are grouped together with "operands".
1802   if (allOperands) {
1803     body << "  if (parser.resolveOperands(allOperands, ";
1804 
1805     // Group all of the operand types together to perform the resolution all at
1806     // once. Use llvm::concat to perform the merge. llvm::concat does not allow
1807     // the case of a single range, so guard it here.
1808     if (op.getNumOperands() > 1) {
1809       body << "::llvm::concat<const ::mlir::Type>(";
1810       llvm::interleaveComma(
1811           llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
1812             body << "::llvm::ArrayRef<::mlir::Type>(";
1813             emitTypeResolver(operandTypes[i], op.getOperand(i).name);
1814             body << ")";
1815           });
1816       body << ")";
1817     } else {
1818       emitTypeResolver(operandTypes.front(), op.getOperand(0).name);
1819     }
1820 
1821     body << ", allOperandLoc, result.operands))\n    return "
1822             "::mlir::failure();\n";
1823     return;
1824   }
1825 
1826   // The final case is the one where each of the operands types are resolved
1827   // separately.
1828   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
1829     NamedTypeConstraint &operand = op.getOperand(i);
1830     body << "  if (parser.resolveOperands(" << operand.name << "Operands, ";
1831 
1832     // Resolve the type of this operand.
1833     TypeResolution &operandType = operandTypes[i];
1834     emitTypeResolver(operandType, operand.name);
1835 
1836     body << ", " << operand.name
1837          << "OperandsLoc, result.operands))\n    return ::mlir::failure();\n";
1838   }
1839 }
1840 
1841 void OperationFormat::genParserRegionResolution(Operator &op,
1842                                                 MethodBody &body) {
1843   // Check for the case where all regions were parsed.
1844   bool hasAllRegions = llvm::any_of(
1845       elements, [](FormatElement *elt) { return isa<RegionsDirective>(elt); });
1846   if (hasAllRegions) {
1847     body << "  result.addRegions(fullRegions);\n";
1848     return;
1849   }
1850 
1851   // Otherwise, handle each region individually.
1852   for (const NamedRegion &region : op.getRegions()) {
1853     if (region.isVariadic())
1854       body << "  result.addRegions(" << region.name << "Regions);\n";
1855     else
1856       body << "  result.addRegion(std::move(" << region.name << "Region));\n";
1857   }
1858 }
1859 
1860 void OperationFormat::genParserSuccessorResolution(Operator &op,
1861                                                    MethodBody &body) {
1862   // Check for the case where all successors were parsed.
1863   bool hasAllSuccessors = llvm::any_of(elements, [](FormatElement *elt) {
1864     return isa<SuccessorsDirective>(elt);
1865   });
1866   if (hasAllSuccessors) {
1867     body << "  result.addSuccessors(fullSuccessors);\n";
1868     return;
1869   }
1870 
1871   // Otherwise, handle each successor individually.
1872   for (const NamedSuccessor &successor : op.getSuccessors()) {
1873     if (successor.isVariadic())
1874       body << "  result.addSuccessors(" << successor.name << "Successors);\n";
1875     else
1876       body << "  result.addSuccessors(" << successor.name << "Successor);\n";
1877   }
1878 }
1879 
1880 void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
1881                                                          MethodBody &body) {
1882   if (!allOperands) {
1883     if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1884       auto interleaveFn = [&](const NamedTypeConstraint &operand) {
1885         // If the operand is variadic emit the parsed size.
1886         if (operand.isVariableLength())
1887           body << "static_cast<int32_t>(" << operand.name << "Operands.size())";
1888         else
1889           body << "1";
1890       };
1891       if (op.getDialect().usePropertiesForAttributes()) {
1892         body << "::llvm::copy(::llvm::ArrayRef<int32_t>({";
1893         llvm::interleaveComma(op.getOperands(), body, interleaveFn);
1894         body << formatv("}), "
1895                         "result.getOrAddProperties<{0}::Properties>()."
1896                         "operandSegmentSizes.begin());\n",
1897                         op.getCppClassName());
1898       } else {
1899         body << "  result.addAttribute(\"operandSegmentSizes\", "
1900              << "parser.getBuilder().getDenseI32ArrayAttr({";
1901         llvm::interleaveComma(op.getOperands(), body, interleaveFn);
1902         body << "}));\n";
1903       }
1904     }
1905     for (const NamedTypeConstraint &operand : op.getOperands()) {
1906       if (!operand.isVariadicOfVariadic())
1907         continue;
1908       if (op.getDialect().usePropertiesForAttributes()) {
1909         body << formatv(
1910             "  result.getOrAddProperties<{0}::Properties>().{1} = "
1911             "parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n",
1912             op.getCppClassName(),
1913             operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
1914             operand.name);
1915       } else {
1916         body << formatv(
1917             "  result.addAttribute(\"{0}\", "
1918             "parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));"
1919             "\n",
1920             operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
1921             operand.name);
1922       }
1923     }
1924   }
1925 
1926   if (!allResultTypes &&
1927       op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
1928     auto interleaveFn = [&](const NamedTypeConstraint &result) {
1929       // If the result is variadic emit the parsed size.
1930       if (result.isVariableLength())
1931         body << "static_cast<int32_t>(" << result.name << "Types.size())";
1932       else
1933         body << "1";
1934     };
1935     if (op.getDialect().usePropertiesForAttributes()) {
1936       body << "::llvm::copy(::llvm::ArrayRef<int32_t>({";
1937       llvm::interleaveComma(op.getResults(), body, interleaveFn);
1938       body << formatv("}), "
1939                       "result.getOrAddProperties<{0}::Properties>()."
1940                       "resultSegmentSizes.begin());\n",
1941                       op.getCppClassName());
1942     } else {
1943       body << "  result.addAttribute(\"resultSegmentSizes\", "
1944            << "parser.getBuilder().getDenseI32ArrayAttr({";
1945       llvm::interleaveComma(op.getResults(), body, interleaveFn);
1946       body << "}));\n";
1947     }
1948   }
1949 }
1950 
1951 //===----------------------------------------------------------------------===//
1952 // PrinterGen
1953 
1954 /// The code snippet used to generate a printer call for a region of an
1955 // operation that has the SingleBlockImplicitTerminator trait.
1956 ///
1957 /// {0}: The name of the region.
1958 const char *regionSingleBlockImplicitTerminatorPrinterCode = R"(
1959   {
1960     bool printTerminator = true;
1961     if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
1962       printTerminator = !term->getAttrDictionary().empty() ||
1963                         term->getNumOperands() != 0 ||
1964                         term->getNumResults() != 0;
1965     }
1966     _odsPrinter.printRegion({0}, /*printEntryBlockArgs=*/true,
1967       /*printBlockTerminators=*/printTerminator);
1968   }
1969 )";
1970 
1971 /// The code snippet used to generate a printer call for an enum that has cases
1972 /// that can't be represented with a keyword.
1973 ///
1974 /// {0}: The name of the enum attribute.
1975 /// {1}: The name of the enum attributes symbolToString function.
1976 const char *enumAttrBeginPrinterCode = R"(
1977   {
1978     auto caseValue = {0}();
1979     auto caseValueStr = {1}(caseValue);
1980 )";
1981 
1982 /// Generate a check that an optional or default-valued attribute or property
1983 /// has a non-default value. For these purposes, the default value of an
1984 /// optional attribute is its presence, even if the attribute itself has a
1985 /// default value.
1986 static void genNonDefaultValueCheck(MethodBody &body, const Operator &op,
1987                                     AttributeVariable &attrElement) {
1988   Attribute attr = attrElement.getVar()->attr;
1989   std::string getter = op.getGetterName(attrElement.getVar()->name);
1990   bool optionalAndDefault = attr.isOptional() && attr.hasDefaultValue();
1991   if (optionalAndDefault)
1992     body << "(";
1993   if (attr.isOptional())
1994     body << getter << "Attr()";
1995   if (optionalAndDefault)
1996     body << " && ";
1997   if (attr.hasDefaultValue()) {
1998     FmtContext fctx;
1999     fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())");
2000     body << getter << "Attr() != "
2001          << tgfmt(attr.getConstBuilderTemplate(), &fctx,
2002                   attr.getDefaultValue());
2003   }
2004   if (optionalAndDefault)
2005     body << ")";
2006 }
2007 
2008 static void genNonDefaultValueCheck(MethodBody &body, const Operator &op,
2009                                     PropertyVariable &propElement) {
2010   body << op.getGetterName(propElement.getVar()->name)
2011        << "() != " << propElement.getVar()->prop.getDefaultValue();
2012 }
2013 
2014 /// Elide the variadic segment size attributes if necessary.
2015 /// This pushes elided attribute names in `elidedStorage`.
2016 static void genVariadicSegmentElision(OperationFormat &fmt, Operator &op,
2017                                       MethodBody &body,
2018                                       const char *elidedStorage) {
2019   if (!fmt.allOperands &&
2020       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
2021     body << "  " << elidedStorage << ".push_back(\"operandSegmentSizes\");\n";
2022   if (!fmt.allResultTypes &&
2023       op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
2024     body << "  " << elidedStorage << ".push_back(\"resultSegmentSizes\");\n";
2025 }
2026 
2027 /// Generate the printer for the 'prop-dict' directive.
2028 static void genPropDictPrinter(OperationFormat &fmt, Operator &op,
2029                                MethodBody &body) {
2030   body << "  ::llvm::SmallVector<::llvm::StringRef, 2> elidedProps;\n";
2031 
2032   genVariadicSegmentElision(fmt, op, body, "elidedProps");
2033 
2034   for (const NamedProperty *namedProperty : fmt.usedProperties)
2035     body << "  elidedProps.push_back(\"" << namedProperty->name << "\");\n";
2036   for (const NamedAttribute *namedAttr : fmt.usedAttributes)
2037     body << "  elidedProps.push_back(\"" << namedAttr->name << "\");\n";
2038 
2039   // Add code to check attributes for equality with their default values.
2040   // Default-valued attributes will not be printed when their value matches the
2041   // default.
2042   for (const NamedAttribute &namedAttr : op.getAttributes()) {
2043     const Attribute &attr = namedAttr.attr;
2044     if (!attr.isDerivedAttr() && attr.hasDefaultValue()) {
2045       const StringRef &name = namedAttr.name;
2046       FmtContext fctx;
2047       fctx.withBuilder("odsBuilder");
2048       std::string defaultValue = std::string(
2049           tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
2050       body << "  {\n";
2051       body << "     ::mlir::Builder odsBuilder(getContext());\n";
2052       body << "     ::mlir::Attribute attr = " << op.getGetterName(name)
2053            << "Attr();\n";
2054       body << "     if(attr && (attr == " << defaultValue << "))\n";
2055       body << "       elidedProps.push_back(\"" << name << "\");\n";
2056       body << "  }\n";
2057     }
2058   }
2059   // Similarly, elide default-valued properties.
2060   for (const NamedProperty &prop : op.getProperties()) {
2061     if (prop.prop.hasDefaultValue()) {
2062       body << "  if (" << op.getGetterName(prop.name)
2063            << "() == " << prop.prop.getDefaultValue() << ") {";
2064       body << "    elidedProps.push_back(\"" << prop.name << "\");\n";
2065       body << "  }\n";
2066     }
2067   }
2068 
2069   if (fmt.useProperties) {
2070     body << "  _odsPrinter << \" \";\n"
2071          << "  printProperties(this->getContext(), _odsPrinter, "
2072             "getProperties(), elidedProps);\n";
2073   }
2074 }
2075 
2076 /// Generate the printer for the 'attr-dict' directive.
2077 static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
2078                                MethodBody &body, bool withKeyword) {
2079   body << "  ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;\n";
2080 
2081   genVariadicSegmentElision(fmt, op, body, "elidedAttrs");
2082 
2083   for (const StringRef key : fmt.inferredAttributes.keys())
2084     body << "  elidedAttrs.push_back(\"" << key << "\");\n";
2085   for (const NamedAttribute *attr : fmt.usedAttributes)
2086     body << "  elidedAttrs.push_back(\"" << attr->name << "\");\n";
2087 
2088   // Add code to check attributes for equality with their default values.
2089   // Default-valued attributes will not be printed when their value matches the
2090   // default.
2091   for (const NamedAttribute &namedAttr : op.getAttributes()) {
2092     const Attribute &attr = namedAttr.attr;
2093     if (!attr.isDerivedAttr() && attr.hasDefaultValue()) {
2094       const StringRef &name = namedAttr.name;
2095       FmtContext fctx;
2096       fctx.withBuilder("odsBuilder");
2097       std::string defaultValue = std::string(
2098           tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
2099       body << "  {\n";
2100       body << "     ::mlir::Builder odsBuilder(getContext());\n";
2101       body << "     ::mlir::Attribute attr = " << op.getGetterName(name)
2102            << "Attr();\n";
2103       body << "     if(attr && (attr == " << defaultValue << "))\n";
2104       body << "       elidedAttrs.push_back(\"" << name << "\");\n";
2105       body << "  }\n";
2106     }
2107   }
2108   if (fmt.hasPropDict)
2109     body << "  _odsPrinter.printOptionalAttrDict"
2110          << (withKeyword ? "WithKeyword" : "")
2111          << "(llvm::to_vector((*this)->getDiscardableAttrs()), elidedAttrs);\n";
2112   else
2113     body << "  _odsPrinter.printOptionalAttrDict"
2114          << (withKeyword ? "WithKeyword" : "")
2115          << "((*this)->getAttrs(), elidedAttrs);\n";
2116 }
2117 
2118 /// Generate the printer for a literal value. `shouldEmitSpace` is true if a
2119 /// space should be emitted before this element. `lastWasPunctuation` is true if
2120 /// the previous element was a punctuation literal.
2121 static void genLiteralPrinter(StringRef value, MethodBody &body,
2122                               bool &shouldEmitSpace, bool &lastWasPunctuation) {
2123   body << "  _odsPrinter";
2124 
2125   // Don't insert a space for certain punctuation.
2126   if (shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation))
2127     body << " << ' '";
2128   body << " << \"" << value << "\";\n";
2129 
2130   // Insert a space after certain literals.
2131   shouldEmitSpace =
2132       value.size() != 1 || !StringRef("<({[").contains(value.front());
2133   lastWasPunctuation = value.front() != '_' && !isalpha(value.front());
2134 }
2135 
2136 /// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation`
2137 /// are set to false.
2138 static void genSpacePrinter(bool value, MethodBody &body, bool &shouldEmitSpace,
2139                             bool &lastWasPunctuation) {
2140   if (value) {
2141     body << "  _odsPrinter << ' ';\n";
2142     lastWasPunctuation = false;
2143   } else {
2144     lastWasPunctuation = true;
2145   }
2146   shouldEmitSpace = false;
2147 }
2148 
2149 /// Generate the printer for a custom directive parameter.
2150 static void genCustomDirectiveParameterPrinter(FormatElement *element,
2151                                                const Operator &op,
2152                                                MethodBody &body) {
2153   if (auto *attr = dyn_cast<AttributeVariable>(element)) {
2154     body << op.getGetterName(attr->getVar()->name) << "Attr()";
2155 
2156   } else if (isa<AttrDictDirective>(element)) {
2157     body << "getOperation()->getAttrDictionary()";
2158 
2159   } else if (isa<PropDictDirective>(element)) {
2160     body << "getProperties()";
2161 
2162   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
2163     body << op.getGetterName(operand->getVar()->name) << "()";
2164 
2165   } else if (auto *region = dyn_cast<RegionVariable>(element)) {
2166     body << op.getGetterName(region->getVar()->name) << "()";
2167 
2168   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
2169     body << op.getGetterName(successor->getVar()->name) << "()";
2170 
2171   } else if (auto *dir = dyn_cast<RefDirective>(element)) {
2172     genCustomDirectiveParameterPrinter(dir->getArg(), op, body);
2173 
2174   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
2175     auto *typeOperand = dir->getArg();
2176     auto *operand = dyn_cast<OperandVariable>(typeOperand);
2177     auto *var = operand ? operand->getVar()
2178                         : cast<ResultVariable>(typeOperand)->getVar();
2179     std::string name = op.getGetterName(var->name);
2180     if (var->isVariadic())
2181       body << name << "().getTypes()";
2182     else if (var->isOptional())
2183       body << formatv("({0}() ? {0}().getType() : ::mlir::Type())", name);
2184     else
2185       body << name << "().getType()";
2186 
2187   } else if (auto *string = dyn_cast<StringElement>(element)) {
2188     FmtContext ctx;
2189     ctx.withBuilder("::mlir::Builder(getContext())");
2190     ctx.addSubst("_ctxt", "getContext()");
2191     body << tgfmt(string->getValue(), &ctx);
2192 
2193   } else if (auto *property = dyn_cast<PropertyVariable>(element)) {
2194     FmtContext ctx;
2195     const NamedProperty *namedProperty = property->getVar();
2196     ctx.addSubst("_storage", "getProperties()." + namedProperty->name);
2197     body << tgfmt(namedProperty->prop.getConvertFromStorageCall(), &ctx);
2198   } else {
2199     llvm_unreachable("unknown custom directive parameter");
2200   }
2201 }
2202 
2203 /// Generate the printer for a custom directive.
2204 static void genCustomDirectivePrinter(CustomDirective *customDir,
2205                                       const Operator &op, MethodBody &body) {
2206   body << "  print" << customDir->getName() << "(_odsPrinter, *this";
2207   for (FormatElement *param : customDir->getArguments()) {
2208     body << ", ";
2209     genCustomDirectiveParameterPrinter(param, op, body);
2210   }
2211   body << ");\n";
2212 }
2213 
2214 /// Generate the printer for a region with the given variable name.
2215 static void genRegionPrinter(const Twine &regionName, MethodBody &body,
2216                              bool hasImplicitTermTrait) {
2217   if (hasImplicitTermTrait)
2218     body << formatv(regionSingleBlockImplicitTerminatorPrinterCode, regionName);
2219   else
2220     body << "  _odsPrinter.printRegion(" << regionName << ");\n";
2221 }
2222 static void genVariadicRegionPrinter(const Twine &regionListName,
2223                                      MethodBody &body,
2224                                      bool hasImplicitTermTrait) {
2225   body << "    llvm::interleaveComma(" << regionListName
2226        << ", _odsPrinter, [&](::mlir::Region &region) {\n      ";
2227   genRegionPrinter("region", body, hasImplicitTermTrait);
2228   body << "    });\n";
2229 }
2230 
2231 /// Generate the C++ for an operand to a (*-)type directive.
2232 static MethodBody &genTypeOperandPrinter(FormatElement *arg, const Operator &op,
2233                                          MethodBody &body,
2234                                          bool useArrayRef = true) {
2235   if (isa<OperandsDirective>(arg))
2236     return body << "getOperation()->getOperandTypes()";
2237   if (isa<ResultsDirective>(arg))
2238     return body << "getOperation()->getResultTypes()";
2239   auto *operand = dyn_cast<OperandVariable>(arg);
2240   auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar();
2241   if (var->isVariadicOfVariadic())
2242     return body << formatv("{0}().join().getTypes()",
2243                            op.getGetterName(var->name));
2244   if (var->isVariadic())
2245     return body << op.getGetterName(var->name) << "().getTypes()";
2246   if (var->isOptional())
2247     return body << formatv(
2248                "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
2249                "::llvm::ArrayRef<::mlir::Type>())",
2250                op.getGetterName(var->name));
2251   if (useArrayRef)
2252     return body << "::llvm::ArrayRef<::mlir::Type>("
2253                 << op.getGetterName(var->name) << "().getType())";
2254   return body << op.getGetterName(var->name) << "().getType()";
2255 }
2256 
2257 /// Generate the printer for an enum attribute.
2258 static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
2259                                MethodBody &body) {
2260   Attribute baseAttr = var->attr.getBaseAttr();
2261   const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
2262   std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
2263 
2264   body << formatv(enumAttrBeginPrinterCode,
2265                   (var->attr.isOptional() ? "*" : "") +
2266                       op.getGetterName(var->name),
2267                   enumAttr.getSymbolToStringFnName());
2268 
2269   // Get a string containing all of the cases that can't be represented with a
2270   // keyword.
2271   BitVector nonKeywordCases(cases.size());
2272   for (auto it : llvm::enumerate(cases)) {
2273     if (!canFormatStringAsKeyword(it.value().getStr()))
2274       nonKeywordCases.set(it.index());
2275   }
2276 
2277   // Otherwise if this is a bit enum attribute, don't allow cases that may
2278   // overlap with other cases. For simplicity sake, only allow cases with a
2279   // single bit value.
2280   if (enumAttr.isBitEnum()) {
2281     for (auto it : llvm::enumerate(cases)) {
2282       int64_t value = it.value().getValue();
2283       if (value < 0 || !llvm::isPowerOf2_64(value))
2284         nonKeywordCases.set(it.index());
2285     }
2286   }
2287 
2288   // If there are any cases that can't be used with a keyword, switch on the
2289   // case value to determine when to print in the string form.
2290   if (nonKeywordCases.any()) {
2291     body << "    switch (caseValue) {\n";
2292     StringRef cppNamespace = enumAttr.getCppNamespace();
2293     StringRef enumName = enumAttr.getEnumClassName();
2294     for (auto it : llvm::enumerate(cases)) {
2295       if (nonKeywordCases.test(it.index()))
2296         continue;
2297       StringRef symbol = it.value().getSymbol();
2298       body << formatv("    case {0}::{1}::{2}:\n", cppNamespace, enumName,
2299                       llvm::isDigit(symbol.front()) ? ("_" + symbol) : symbol);
2300     }
2301     body << "      _odsPrinter << caseValueStr;\n"
2302             "      break;\n"
2303             "    default:\n"
2304             "      _odsPrinter << '\"' << caseValueStr << '\"';\n"
2305             "      break;\n"
2306             "    }\n"
2307             "  }\n";
2308     return;
2309   }
2310 
2311   body << "    _odsPrinter << caseValueStr;\n"
2312           "  }\n";
2313 }
2314 
2315 /// Generate the check for the anchor of an optional group.
2316 static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
2317                                           const Operator &op,
2318                                           MethodBody &body) {
2319   TypeSwitch<FormatElement *>(anchor)
2320       .Case<OperandVariable, ResultVariable>([&](auto *element) {
2321         const NamedTypeConstraint *var = element->getVar();
2322         std::string name = op.getGetterName(var->name);
2323         if (var->isOptional())
2324           body << name << "()";
2325         else if (var->isVariadic())
2326           body << "!" << name << "().empty()";
2327       })
2328       .Case([&](RegionVariable *element) {
2329         const NamedRegion *var = element->getVar();
2330         std::string name = op.getGetterName(var->name);
2331         // TODO: Add a check for optional regions here when ODS supports it.
2332         body << "!" << name << "().empty()";
2333       })
2334       .Case([&](TypeDirective *element) {
2335         genOptionalGroupPrinterAnchor(element->getArg(), op, body);
2336       })
2337       .Case([&](FunctionalTypeDirective *element) {
2338         genOptionalGroupPrinterAnchor(element->getInputs(), op, body);
2339       })
2340       .Case([&](AttributeVariable *element) {
2341         // Consider a default-valued attribute as present if it's not the
2342         // default value and an optional one present if it is set.
2343         genNonDefaultValueCheck(body, op, *element);
2344       })
2345       .Case([&](PropertyVariable *element) {
2346         genNonDefaultValueCheck(body, op, *element);
2347       })
2348       .Case([&](CustomDirective *ele) {
2349         body << '(';
2350         llvm::interleave(
2351             ele->getArguments(), body,
2352             [&](FormatElement *child) {
2353               body << '(';
2354               genOptionalGroupPrinterAnchor(child, op, body);
2355               body << ')';
2356             },
2357             " || ");
2358         body << ')';
2359       });
2360 }
2361 
2362 void collect(FormatElement *element,
2363              SmallVectorImpl<VariableElement *> &variables) {
2364   TypeSwitch<FormatElement *>(element)
2365       .Case([&](VariableElement *var) { variables.emplace_back(var); })
2366       .Case([&](CustomDirective *ele) {
2367         for (FormatElement *arg : ele->getArguments())
2368           collect(arg, variables);
2369       })
2370       .Case([&](OptionalElement *ele) {
2371         for (FormatElement *arg : ele->getThenElements())
2372           collect(arg, variables);
2373         for (FormatElement *arg : ele->getElseElements())
2374           collect(arg, variables);
2375       })
2376       .Case([&](FunctionalTypeDirective *funcType) {
2377         collect(funcType->getInputs(), variables);
2378         collect(funcType->getResults(), variables);
2379       })
2380       .Case([&](OIListElement *oilist) {
2381         for (ArrayRef<FormatElement *> arg : oilist->getParsingElements())
2382           for (FormatElement *arg : arg)
2383             collect(arg, variables);
2384       });
2385 }
2386 
2387 void OperationFormat::genElementPrinter(FormatElement *element,
2388                                         MethodBody &body, Operator &op,
2389                                         bool &shouldEmitSpace,
2390                                         bool &lastWasPunctuation) {
2391   if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
2392     return genLiteralPrinter(literal->getSpelling(), body, shouldEmitSpace,
2393                              lastWasPunctuation);
2394 
2395   // Emit a whitespace element.
2396   if (auto *space = dyn_cast<WhitespaceElement>(element)) {
2397     if (space->getValue() == "\\n") {
2398       body << "  _odsPrinter.printNewline();\n";
2399     } else {
2400       genSpacePrinter(!space->getValue().empty(), body, shouldEmitSpace,
2401                       lastWasPunctuation);
2402     }
2403     return;
2404   }
2405 
2406   // Emit an optional group.
2407   if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
2408     // Emit the check for the presence of the anchor element.
2409     FormatElement *anchor = optional->getAnchor();
2410     body << "  if (";
2411     if (optional->isInverted())
2412       body << "!";
2413     genOptionalGroupPrinterAnchor(anchor, op, body);
2414     body << ") {\n";
2415     body.indent();
2416 
2417     // If the anchor is a unit attribute, we don't need to print it. When
2418     // parsing, we will add this attribute if this group is present.
2419     ArrayRef<FormatElement *> thenElements = optional->getThenElements();
2420     ArrayRef<FormatElement *> elseElements = optional->getElseElements();
2421     FormatElement *elidedAnchorElement = nullptr;
2422     auto *anchorAttr = dyn_cast<AttributeLikeVariable>(anchor);
2423     if (anchorAttr && anchorAttr != thenElements.front() &&
2424         (elseElements.empty() || anchorAttr != elseElements.front()) &&
2425         anchorAttr->isUnit()) {
2426       elidedAnchorElement = anchorAttr;
2427     }
2428     auto genElementPrinters = [&](ArrayRef<FormatElement *> elements) {
2429       for (FormatElement *childElement : elements) {
2430         if (childElement != elidedAnchorElement) {
2431           genElementPrinter(childElement, body, op, shouldEmitSpace,
2432                             lastWasPunctuation);
2433         }
2434       }
2435     };
2436 
2437     // Emit each of the elements.
2438     genElementPrinters(thenElements);
2439     body << "}";
2440 
2441     // Emit each of the else elements.
2442     if (!elseElements.empty()) {
2443       body << " else {\n";
2444       genElementPrinters(elseElements);
2445       body << "}";
2446     }
2447 
2448     body.unindent() << "\n";
2449     return;
2450   }
2451 
2452   // Emit the OIList
2453   if (auto *oilist = dyn_cast<OIListElement>(element)) {
2454     for (auto clause : oilist->getClauses()) {
2455       LiteralElement *lelement = std::get<0>(clause);
2456       ArrayRef<FormatElement *> pelement = std::get<1>(clause);
2457 
2458       SmallVector<VariableElement *> vars;
2459       for (FormatElement *el : pelement)
2460         collect(el, vars);
2461       body << "  if (false";
2462       for (VariableElement *var : vars) {
2463         TypeSwitch<FormatElement *>(var)
2464             .Case([&](AttributeVariable *attrEle) {
2465               body << " || (";
2466               genNonDefaultValueCheck(body, op, *attrEle);
2467               body << ")";
2468             })
2469             .Case([&](PropertyVariable *propEle) {
2470               body << " || (";
2471               genNonDefaultValueCheck(body, op, *propEle);
2472               body << ")";
2473             })
2474             .Case([&](OperandVariable *ele) {
2475               if (ele->getVar()->isVariadic()) {
2476                 body << " || " << op.getGetterName(ele->getVar()->name)
2477                      << "().size()";
2478               } else {
2479                 body << " || " << op.getGetterName(ele->getVar()->name) << "()";
2480               }
2481             })
2482             .Case([&](ResultVariable *ele) {
2483               if (ele->getVar()->isVariadic()) {
2484                 body << " || " << op.getGetterName(ele->getVar()->name)
2485                      << "().size()";
2486               } else {
2487                 body << " || " << op.getGetterName(ele->getVar()->name) << "()";
2488               }
2489             })
2490             .Case([&](RegionVariable *reg) {
2491               body << " || " << op.getGetterName(reg->getVar()->name) << "()";
2492             });
2493       }
2494 
2495       body << ") {\n";
2496       genLiteralPrinter(lelement->getSpelling(), body, shouldEmitSpace,
2497                         lastWasPunctuation);
2498       if (oilist->getUnitVariableParsingElement(pelement) == nullptr) {
2499         for (FormatElement *element : pelement)
2500           genElementPrinter(element, body, op, shouldEmitSpace,
2501                             lastWasPunctuation);
2502       }
2503       body << "  }\n";
2504     }
2505     return;
2506   }
2507 
2508   // Emit the attribute dictionary.
2509   if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
2510     genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword());
2511     lastWasPunctuation = false;
2512     return;
2513   }
2514 
2515   // Emit the property dictionary.
2516   if (isa<PropDictDirective>(element)) {
2517     genPropDictPrinter(*this, op, body);
2518     lastWasPunctuation = false;
2519     return;
2520   }
2521 
2522   // Optionally insert a space before the next element. The AttrDict printer
2523   // already adds a space as necessary.
2524   if (shouldEmitSpace || !lastWasPunctuation)
2525     body << "  _odsPrinter << ' ';\n";
2526   lastWasPunctuation = false;
2527   shouldEmitSpace = true;
2528 
2529   if (auto *attr = dyn_cast<AttributeVariable>(element)) {
2530     const NamedAttribute *var = attr->getVar();
2531 
2532     // If we are formatting as an enum, symbolize the attribute as a string.
2533     if (canFormatEnumAttr(var))
2534       return genEnumAttrPrinter(var, op, body);
2535 
2536     // If we are formatting as a symbol name, handle it as a symbol name.
2537     if (shouldFormatSymbolNameAttr(var)) {
2538       body << "  _odsPrinter.printSymbolName(" << op.getGetterName(var->name)
2539            << "Attr().getValue());\n";
2540       return;
2541     }
2542 
2543     // Elide the attribute type if it is buildable.
2544     if (attr->getTypeBuilder())
2545       body << "  _odsPrinter.printAttributeWithoutType("
2546            << op.getGetterName(var->name) << "Attr());\n";
2547     else if (attr->shouldBeQualified() ||
2548              var->attr.getStorageType() == "::mlir::Attribute")
2549       body << "  _odsPrinter.printAttribute(" << op.getGetterName(var->name)
2550            << "Attr());\n";
2551     else
2552       body << "_odsPrinter.printStrippedAttrOrType("
2553            << op.getGetterName(var->name) << "Attr());\n";
2554   } else if (auto *property = dyn_cast<PropertyVariable>(element)) {
2555     const NamedProperty *var = property->getVar();
2556     FmtContext fmtContext;
2557     fmtContext.addSubst("_printer", "_odsPrinter");
2558     fmtContext.addSubst("_ctxt", "getContext()");
2559     fmtContext.addSubst("_storage", "getProperties()." + var->name);
2560     body << tgfmt(var->prop.getPrinterCall(), &fmtContext) << ";\n";
2561   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
2562     if (operand->getVar()->isVariadicOfVariadic()) {
2563       body << "  ::llvm::interleaveComma("
2564            << op.getGetterName(operand->getVar()->name)
2565            << "(), _odsPrinter, [&](const auto &operands) { _odsPrinter << "
2566               "\"(\" << operands << "
2567               "\")\"; });\n";
2568 
2569     } else if (operand->getVar()->isOptional()) {
2570       body << "  if (::mlir::Value value = "
2571            << op.getGetterName(operand->getVar()->name) << "())\n"
2572            << "    _odsPrinter << value;\n";
2573     } else {
2574       body << "  _odsPrinter << " << op.getGetterName(operand->getVar()->name)
2575            << "();\n";
2576     }
2577   } else if (auto *region = dyn_cast<RegionVariable>(element)) {
2578     const NamedRegion *var = region->getVar();
2579     std::string name = op.getGetterName(var->name);
2580     if (var->isVariadic()) {
2581       genVariadicRegionPrinter(name + "()", body, hasImplicitTermTrait);
2582     } else {
2583       genRegionPrinter(name + "()", body, hasImplicitTermTrait);
2584     }
2585   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
2586     const NamedSuccessor *var = successor->getVar();
2587     std::string name = op.getGetterName(var->name);
2588     if (var->isVariadic())
2589       body << "  ::llvm::interleaveComma(" << name << "(), _odsPrinter);\n";
2590     else
2591       body << "  _odsPrinter << " << name << "();\n";
2592   } else if (auto *dir = dyn_cast<CustomDirective>(element)) {
2593     genCustomDirectivePrinter(dir, op, body);
2594   } else if (isa<OperandsDirective>(element)) {
2595     body << "  _odsPrinter << getOperation()->getOperands();\n";
2596   } else if (isa<RegionsDirective>(element)) {
2597     genVariadicRegionPrinter("getOperation()->getRegions()", body,
2598                              hasImplicitTermTrait);
2599   } else if (isa<SuccessorsDirective>(element)) {
2600     body << "  ::llvm::interleaveComma(getOperation()->getSuccessors(), "
2601             "_odsPrinter);\n";
2602   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
2603     if (auto *operand = dyn_cast<OperandVariable>(dir->getArg())) {
2604       if (operand->getVar()->isVariadicOfVariadic()) {
2605         body << formatv(
2606             "  ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, "
2607             "[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << "
2608             "types << \")\"; });\n",
2609             op.getGetterName(operand->getVar()->name));
2610         return;
2611       }
2612     }
2613     const NamedTypeConstraint *var = nullptr;
2614     {
2615       if (auto *operand = dyn_cast<OperandVariable>(dir->getArg()))
2616         var = operand->getVar();
2617       else if (auto *operand = dyn_cast<ResultVariable>(dir->getArg()))
2618         var = operand->getVar();
2619     }
2620     if (var && !var->isVariadicOfVariadic() && !var->isVariadic() &&
2621         !var->isOptional()) {
2622       StringRef cppType = var->constraint.getCppType();
2623       if (dir->shouldBeQualified()) {
2624         body << "   _odsPrinter << " << op.getGetterName(var->name)
2625              << "().getType();\n";
2626         return;
2627       }
2628       body << "  {\n"
2629            << "    auto type = " << op.getGetterName(var->name)
2630            << "().getType();\n"
2631            << "    if (auto validType = ::llvm::dyn_cast<" << cppType
2632            << ">(type))\n"
2633            << "      _odsPrinter.printStrippedAttrOrType(validType);\n"
2634            << "   else\n"
2635            << "     _odsPrinter << type;\n"
2636            << "  }\n";
2637       return;
2638     }
2639     body << "  _odsPrinter << ";
2640     genTypeOperandPrinter(dir->getArg(), op, body, /*useArrayRef=*/false)
2641         << ";\n";
2642   } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
2643     body << "  _odsPrinter.printFunctionalType(";
2644     genTypeOperandPrinter(dir->getInputs(), op, body) << ", ";
2645     genTypeOperandPrinter(dir->getResults(), op, body) << ");\n";
2646   } else {
2647     llvm_unreachable("unknown format element");
2648   }
2649 }
2650 
2651 void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
2652   auto *method = opClass.addMethod(
2653       "void", "print",
2654       MethodParameter("::mlir::OpAsmPrinter &", "_odsPrinter"));
2655   auto &body = method->body();
2656 
2657   // Flags for if we should emit a space, and if the last element was
2658   // punctuation.
2659   bool shouldEmitSpace = true, lastWasPunctuation = false;
2660   for (FormatElement *element : elements)
2661     genElementPrinter(element, body, op, shouldEmitSpace, lastWasPunctuation);
2662 }
2663 
2664 //===----------------------------------------------------------------------===//
2665 // OpFormatParser
2666 //===----------------------------------------------------------------------===//
2667 
2668 /// Function to find an element within the given range that has the same name as
2669 /// 'name'.
2670 template <typename RangeT>
2671 static auto findArg(RangeT &&range, StringRef name) {
2672   auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
2673   return it != range.end() ? &*it : nullptr;
2674 }
2675 
namespace {
/// This class implements a parser for an instance of an operation assembly
/// format. On top of the generic FormatParser, it tracks which components of
/// the operation (attributes, properties, operands, regions, successors, and
/// operand/result types) the format references. `verify` uses that state to
/// check the format and then hands it over to the target OperationFormat.
class OpFormatParser : public FormatParser {
public:
  /// Create a parser for the assembly format of `op` that populates `format`.
  /// The base parser receives the first of `op`'s source locations. The
  /// operand/result type bit vectors are sized to the number of
  /// operands/results of `op`.
  OpFormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op)
      : FormatParser(mgr, op.getLoc()[0]), fmt(format), op(op),
        seenOperandTypes(op.getNumOperands()),
        seenResultTypes(op.getNumResults()) {}

protected:
  /// Verify the format elements.
  LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override;
  /// Verify the arguments to a custom directive.
  LogicalResult
  verifyCustomDirectiveArguments(SMLoc loc,
                                 ArrayRef<FormatElement *> arguments) override;
  /// Verify the elements of an optional group.
  LogicalResult verifyOptionalGroupElements(SMLoc loc,
                                            ArrayRef<FormatElement *> elements,
                                            FormatElement *anchor) override;
  /// Verify a single element of an optional group. `isAnchor` indicates
  /// whether `element` is the group's anchor.
  LogicalResult verifyOptionalGroupElement(SMLoc loc, FormatElement *element,
                                           bool isAnchor);

  /// Mark `element` as requiring its qualified printed form.
  /// NOTE(review): inferred from the printer's `shouldBeQualified()` queries;
  /// confirm against the implementation.
  LogicalResult markQualified(SMLoc loc, FormatElement *element) override;

  /// Parse an operation variable.
  FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
                                               Context ctx) override;
  /// Parse an operation format directive.
  FailureOr<FormatElement *>
  parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override;

private:
  /// This struct represents a type resolution instance. It includes a specific
  /// type as well as an optional transformer to apply to that type in order to
  /// properly resolve the type of a variable.
  struct TypeResolutionInstance {
    ConstArgument resolver;
    std::optional<StringRef> transformer;
  };

  /// Verify the state of operation attributes within the format.
  LogicalResult verifyAttributes(SMLoc loc, ArrayRef<FormatElement *> elements);

  /// Verify that attributes elements aren't followed by colon literals.
  LogicalResult verifyAttributeColonType(SMLoc loc,
                                         ArrayRef<FormatElement *> elements);
  /// Verify that the attribute dictionary directive isn't followed by a region.
  LogicalResult verifyAttrDictRegion(SMLoc loc,
                                     ArrayRef<FormatElement *> elements);

  /// Verify the state of operation operands within the format.
  LogicalResult
  verifyOperands(SMLoc loc,
                 StringMap<TypeResolutionInstance> &variableTyResolver);

  /// Verify the state of operation regions within the format.
  LogicalResult verifyRegions(SMLoc loc);

  /// Verify the state of operation results within the format.
  LogicalResult
  verifyResults(SMLoc loc,
                StringMap<TypeResolutionInstance> &variableTyResolver);

  /// Verify the state of operation successors within the format.
  LogicalResult verifySuccessors(SMLoc loc);

  /// Verify the `oilist` elements within the format.
  LogicalResult verifyOIListElements(SMLoc loc,
                                     ArrayRef<FormatElement *> elements);

  /// Given the values of an `AllTypesMatch` trait, check for inferable type
  /// resolution.
  void handleAllTypesMatchConstraint(
      ArrayRef<StringRef> values,
      StringMap<TypeResolutionInstance> &variableTyResolver);
  /// Check for inferable type resolution given all operands, and/or results,
  /// have the same type. If 'includeResults' is true, the results also have the
  /// same type as all of the operands.
  void handleSameTypesConstraint(
      StringMap<TypeResolutionInstance> &variableTyResolver,
      bool includeResults);
  /// Check for inferable type resolution based on another operand, result, or
  /// attribute.
  void handleTypesMatchConstraint(
      StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def);

  /// Returns an argument or attribute with the given name that has been seen
  /// within the format.
  ConstArgument findSeenArg(StringRef name);

  /// Parse the various different directives.
  FailureOr<FormatElement *> parsePropDictDirective(SMLoc loc, Context context);
  FailureOr<FormatElement *> parseAttrDictDirective(SMLoc loc, Context context,
                                                    bool withKeyword);
  FailureOr<FormatElement *> parseFunctionalTypeDirective(SMLoc loc,
                                                          Context context);
  FailureOr<FormatElement *> parseOIListDirective(SMLoc loc, Context context);
  /// Verify one parsing element of an `oilist` clause.
  LogicalResult verifyOIListParsingElement(FormatElement *element, SMLoc loc);
  FailureOr<FormatElement *> parseOperandsDirective(SMLoc loc, Context context);
  FailureOr<FormatElement *> parseRegionsDirective(SMLoc loc, Context context);
  FailureOr<FormatElement *> parseResultsDirective(SMLoc loc, Context context);
  FailureOr<FormatElement *> parseSuccessorsDirective(SMLoc loc,
                                                      Context context);
  FailureOr<FormatElement *> parseTypeDirective(SMLoc loc, Context context);
  /// Parse the operand of a `type` directive. `isRefChild` presumably marks
  /// an operand nested inside a `ref` directive -- TODO confirm.
  FailureOr<FormatElement *> parseTypeDirectiveOperand(SMLoc loc,
                                                       bool isRefChild = false);

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  /// The operation format populated by this parser.
  OperationFormat &fmt;
  /// The operation whose assembly format is being parsed.
  Operator &op;

  // The following are various bits of format state used for verification
  // during parsing.
  /// Whether an `attr-dict` directive was seen; `verify` fails without one.
  bool hasAttrDict = false;
  /// Whether a `prop-dict` directive was seen; `verify` copies this into the
  /// format.
  bool hasPropDict = false;
  bool hasAllRegions = false, hasAllSuccessors = false;
  /// Set by `verify` when not all result types are known and a trait names
  /// `::mlir::InferTypeOpInterface`.
  bool canInferResultTypes = false;
  /// Seen-type markers, sized to the op's operand/result counts in the
  /// constructor (presumably indexed by operand/result number).
  llvm::SmallBitVector seenOperandTypes, seenResultTypes;
  /// Attributes referenced by the format; moved into `fmt.usedAttributes` by
  /// `verify`.
  llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
  /// Operands referenced by the format; `verifyAttributes` scans these for
  /// VariadicOfVariadic segment-size attributes.
  llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
  llvm::DenseSet<const NamedRegion *> seenRegions;
  llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
  /// Properties referenced by the format; moved into `fmt.usedProperties` by
  /// `verify`.
  llvm::SmallSetVector<const NamedProperty *, 8> seenProperties;
};
} // namespace
2805 
2806 LogicalResult OpFormatParser::verify(SMLoc loc,
2807                                      ArrayRef<FormatElement *> elements) {
2808   // Check that the attribute dictionary is in the format.
2809   if (!hasAttrDict)
2810     return emitError(loc, "'attr-dict' directive not found in "
2811                           "custom assembly format");
2812 
2813   // Check for any type traits that we can use for inferring types.
2814   StringMap<TypeResolutionInstance> variableTyResolver;
2815   for (const Trait &trait : op.getTraits()) {
2816     const Record &def = trait.getDef();
2817     if (def.isSubClassOf("AllTypesMatch")) {
2818       handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
2819                                     variableTyResolver);
2820     } else if (def.getName() == "SameTypeOperands") {
2821       handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false);
2822     } else if (def.getName() == "SameOperandsAndResultType") {
2823       handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
2824     } else if (def.isSubClassOf("TypesMatchWith")) {
2825       handleTypesMatchConstraint(variableTyResolver, def);
2826     } else if (!op.allResultTypesKnown()) {
2827       // This doesn't check the name directly to handle
2828       //    DeclareOpInterfaceMethods<InferTypeOpInterface>
2829       // and the like.
2830       // TODO: Add hasCppInterface check.
2831       if (auto name = def.getValueAsOptionalString("cppInterfaceName")) {
2832         if (*name == "InferTypeOpInterface" &&
2833             def.getValueAsString("cppNamespace") == "::mlir")
2834           canInferResultTypes = true;
2835       }
2836     }
2837   }
2838 
2839   // Verify the state of the various operation components.
2840   if (failed(verifyAttributes(loc, elements)) ||
2841       failed(verifyResults(loc, variableTyResolver)) ||
2842       failed(verifyOperands(loc, variableTyResolver)) ||
2843       failed(verifyRegions(loc)) || failed(verifySuccessors(loc)) ||
2844       failed(verifyOIListElements(loc, elements)))
2845     return failure();
2846 
2847   // Collect the set of used attributes in the format.
2848   fmt.usedAttributes = std::move(seenAttrs);
2849   fmt.usedProperties = std::move(seenProperties);
2850 
2851   // Set whether prop-dict is used in the format
2852   fmt.hasPropDict = hasPropDict;
2853   return success();
2854 }
2855 
2856 LogicalResult
2857 OpFormatParser::verifyAttributes(SMLoc loc,
2858                                  ArrayRef<FormatElement *> elements) {
2859   // Check that there are no `:` literals after an attribute without a constant
2860   // type. The attribute grammar contains an optional trailing colon type, which
2861   // can lead to unexpected and generally unintended behavior. Given that, it is
2862   // better to just error out here instead.
2863   if (failed(verifyAttributeColonType(loc, elements)))
2864     return failure();
2865   // Check that there are no region variables following an attribute dicitonary.
2866   // Both start with `{` and so the optional attribute dictionary can cause
2867   // format ambiguities.
2868   if (failed(verifyAttrDictRegion(loc, elements)))
2869     return failure();
2870 
2871   // Check for VariadicOfVariadic variables. The segment attribute of those
2872   // variables will be infered.
2873   for (const NamedTypeConstraint *var : seenOperands) {
2874     if (var->constraint.isVariadicOfVariadic()) {
2875       fmt.inferredAttributes.insert(
2876           var->constraint.getVariadicOfVariadicSegmentSizeAttr());
2877     }
2878   }
2879 
2880   return success();
2881 }
2882 
2883 /// Returns whether the single format element is optionally parsed.
2884 static bool isOptionallyParsed(FormatElement *el) {
2885   if (auto *attrVar = dyn_cast<AttributeVariable>(el)) {
2886     Attribute attr = attrVar->getVar()->attr;
2887     return attr.isOptional() || attr.hasDefaultValue();
2888   }
2889   if (auto *propVar = dyn_cast<PropertyVariable>(el)) {
2890     const Property &prop = propVar->getVar()->prop;
2891     return prop.hasDefaultValue() && prop.hasOptionalParser();
2892   }
2893   if (auto *operandVar = dyn_cast<OperandVariable>(el)) {
2894     const NamedTypeConstraint *operand = operandVar->getVar();
2895     return operand->isOptional() || operand->isVariadic() ||
2896            operand->isVariadicOfVariadic();
2897   }
2898   if (auto *successorVar = dyn_cast<SuccessorVariable>(el))
2899     return successorVar->getVar()->isVariadic();
2900   if (auto *regionVar = dyn_cast<RegionVariable>(el))
2901     return regionVar->getVar()->isVariadic();
2902   return isa<WhitespaceElement, AttrDictDirective>(el);
2903 }
2904 
2905 /// Scan the given range of elements from the start for an invalid format
2906 /// element that satisfies `isInvalid`, skipping any optionally-parsed elements.
2907 /// If an optional group is encountered, this function recurses into the 'then'
2908 /// and 'else' elements to check if they are invalid. Returns `success` if the
2909 /// range is known to be valid or `std::nullopt` if scanning reached the end.
2910 ///
2911 /// Since the guard element of an optional group is required, this function
2912 /// accepts an optional element pointer to mark it as required.
2913 static std::optional<LogicalResult> checkRangeForElement(
2914     FormatElement *base,
2915     function_ref<bool(FormatElement *, FormatElement *)> isInvalid,
2916     iterator_range<ArrayRef<FormatElement *>::iterator> elementRange,
2917     FormatElement *optionalGuard = nullptr) {
2918   for (FormatElement *element : elementRange) {
2919     // If we encounter an invalid element, return an error.
2920     if (isInvalid(base, element))
2921       return failure();
2922 
2923     // Recurse on optional groups.
2924     if (auto *optional = dyn_cast<OptionalElement>(element)) {
2925       if (std::optional<LogicalResult> result = checkRangeForElement(
2926               base, isInvalid, optional->getThenElements(),
2927               // The optional group guard is required for the group.
2928               optional->getThenElements().front()))
2929         if (failed(*result))
2930           return failure();
2931       if (std::optional<LogicalResult> result = checkRangeForElement(
2932               base, isInvalid, optional->getElseElements()))
2933         if (failed(*result))
2934           return failure();
2935       // Skip the optional group.
2936       continue;
2937     }
2938 
2939     // Skip optionally parsed elements.
2940     if (element != optionalGuard && isOptionallyParsed(element))
2941       continue;
2942 
2943     // We found a closing element that is valid.
2944     return success();
2945   }
2946   // Return std::nullopt to indicate that we reached the end.
2947   return std::nullopt;
2948 }
2949 
2950 /// For the given elements, check whether any attributes are followed by a colon
2951 /// literal, resulting in an ambiguous assembly format. Returns a non-null
2952 /// attribute if verification of said attribute reached the end of the range.
2953 /// Returns null if all attribute elements are verified.
2954 static FailureOr<FormatElement *> verifyAdjacentElements(
2955     function_ref<bool(FormatElement *)> isBase,
2956     function_ref<bool(FormatElement *, FormatElement *)> isInvalid,
2957     ArrayRef<FormatElement *> elements) {
2958   for (auto *it = elements.begin(), *e = elements.end(); it != e; ++it) {
2959     // The current attribute being verified.
2960     FormatElement *base;
2961 
2962     if (isBase(*it)) {
2963       base = *it;
2964     } else if (auto *optional = dyn_cast<OptionalElement>(*it)) {
2965       // Recurse on optional groups.
2966       FailureOr<FormatElement *> thenResult = verifyAdjacentElements(
2967           isBase, isInvalid, optional->getThenElements());
2968       if (failed(thenResult))
2969         return failure();
2970       FailureOr<FormatElement *> elseResult = verifyAdjacentElements(
2971           isBase, isInvalid, optional->getElseElements());
2972       if (failed(elseResult))
2973         return failure();
2974       // If either optional group has an unverified attribute, save it.
2975       // Otherwise, move on to the next element.
2976       if (!(base = *thenResult) && !(base = *elseResult))
2977         continue;
2978     } else {
2979       continue;
2980     }
2981 
2982     // Verify subsequent elements for potential ambiguities.
2983     if (std::optional<LogicalResult> result =
2984             checkRangeForElement(base, isInvalid, {std::next(it), e})) {
2985       if (failed(*result))
2986         return failure();
2987     } else {
2988       // Since we reached the end, return the attribute as unverified.
2989       return base;
2990     }
2991   }
2992   // All attribute elements are known to be verified.
2993   return nullptr;
2994 }
2995 
2996 LogicalResult
2997 OpFormatParser::verifyAttributeColonType(SMLoc loc,
2998                                          ArrayRef<FormatElement *> elements) {
2999   auto isBase = [](FormatElement *el) {
3000     auto *attr = dyn_cast<AttributeVariable>(el);
3001     if (!attr)
3002       return false;
3003     // Check only attributes without type builders or that are known to call
3004     // the generic attribute parser.
3005     return !attr->getTypeBuilder() &&
3006            (attr->shouldBeQualified() ||
3007             attr->getVar()->attr.getStorageType() == "::mlir::Attribute");
3008   };
3009   auto isInvalid = [&](FormatElement *base, FormatElement *el) {
3010     auto *literal = dyn_cast<LiteralElement>(el);
3011     if (!literal || literal->getSpelling() != ":")
3012       return false;
3013     // If we encounter `:`, the range is known to be invalid.
3014     (void)emitError(
3015         loc, formatv("format ambiguity caused by `:` literal found after "
3016                      "attribute `{0}` which does not have a buildable type",
3017                      cast<AttributeVariable>(base)->getVar()->name));
3018     return true;
3019   };
3020   return verifyAdjacentElements(isBase, isInvalid, elements);
3021 }
3022 
3023 LogicalResult
3024 OpFormatParser::verifyAttrDictRegion(SMLoc loc,
3025                                      ArrayRef<FormatElement *> elements) {
3026   auto isBase = [](FormatElement *el) {
3027     if (auto *attrDict = dyn_cast<AttrDictDirective>(el))
3028       return !attrDict->isWithKeyword();
3029     return false;
3030   };
3031   auto isInvalid = [&](FormatElement *base, FormatElement *el) {
3032     auto *region = dyn_cast<RegionVariable>(el);
3033     if (!region)
3034       return false;
3035     (void)emitErrorAndNote(
3036         loc,
3037         formatv("format ambiguity caused by `attr-dict` directive "
3038                 "followed by region `{0}`",
3039                 region->getVar()->name),
3040         "try using `attr-dict-with-keyword` instead");
3041     return true;
3042   };
3043   return verifyAdjacentElements(isBase, isInvalid, elements);
3044 }
3045 
3046 LogicalResult OpFormatParser::verifyOperands(
3047     SMLoc loc, StringMap<TypeResolutionInstance> &variableTyResolver) {
3048   // Check that all of the operands are within the format, and their types can
3049   // be inferred.
3050   auto &buildableTypes = fmt.buildableTypes;
3051   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
3052     NamedTypeConstraint &operand = op.getOperand(i);
3053 
3054     // Check that the operand itself is in the format.
3055     if (!fmt.allOperands && !seenOperands.count(&operand)) {
3056       return emitErrorAndNote(loc,
3057                               "operand #" + Twine(i) + ", named '" +
3058                                   operand.name + "', not found",
3059                               "suggest adding a '$" + operand.name +
3060                                   "' directive to the custom assembly format");
3061     }
3062 
3063     // Check that the operand type is in the format, or that it can be inferred.
3064     if (fmt.allOperandTypes || seenOperandTypes.test(i))
3065       continue;
3066 
3067     // Check to see if we can infer this type from another variable.
3068     auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
3069     if (varResolverIt != variableTyResolver.end()) {
3070       TypeResolutionInstance &resolver = varResolverIt->second;
3071       fmt.operandTypes[i].setResolver(resolver.resolver, resolver.transformer);
3072       continue;
3073     }
3074 
3075     // Similarly to results, allow a custom builder for resolving the type if
3076     // we aren't using the 'operands' directive.
3077     std::optional<StringRef> builder = operand.constraint.getBuilderCall();
3078     if (!builder || (fmt.allOperands && operand.isVariableLength())) {
3079       return emitErrorAndNote(
3080           loc,
3081           "type of operand #" + Twine(i) + ", named '" + operand.name +
3082               "', is not buildable and a buildable type cannot be inferred",
3083           "suggest adding a type constraint to the operation or adding a "
3084           "'type($" +
3085               operand.name + ")' directive to the " + "custom assembly format");
3086     }
3087     auto it = buildableTypes.insert({*builder, buildableTypes.size()});
3088     fmt.operandTypes[i].setBuilderIdx(it.first->second);
3089   }
3090   return success();
3091 }
3092 
3093 LogicalResult OpFormatParser::verifyRegions(SMLoc loc) {
3094   // Check that all of the regions are within the format.
3095   if (hasAllRegions)
3096     return success();
3097 
3098   for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) {
3099     const NamedRegion &region = op.getRegion(i);
3100     if (!seenRegions.count(&region)) {
3101       return emitErrorAndNote(loc,
3102                               "region #" + Twine(i) + ", named '" +
3103                                   region.name + "', not found",
3104                               "suggest adding a '$" + region.name +
3105                                   "' directive to the custom assembly format");
3106     }
3107   }
3108   return success();
3109 }
3110 
3111 LogicalResult OpFormatParser::verifyResults(
3112     SMLoc loc, StringMap<TypeResolutionInstance> &variableTyResolver) {
3113   // If we format all of the types together, there is nothing to check.
3114   if (fmt.allResultTypes)
3115     return success();
3116 
3117   // If no result types are specified and we can infer them, infer all result
3118   // types
3119   if (op.getNumResults() > 0 && seenResultTypes.count() == 0 &&
3120       canInferResultTypes) {
3121     fmt.infersResultTypes = true;
3122     return success();
3123   }
3124 
3125   // Check that all of the result types can be inferred.
3126   auto &buildableTypes = fmt.buildableTypes;
3127   for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
3128     if (seenResultTypes.test(i))
3129       continue;
3130 
3131     // Check to see if we can infer this type from another variable.
3132     auto varResolverIt = variableTyResolver.find(op.getResultName(i));
3133     if (varResolverIt != variableTyResolver.end()) {
3134       TypeResolutionInstance resolver = varResolverIt->second;
3135       fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer);
3136       continue;
3137     }
3138 
3139     // If the result is not variable length, allow for the case where the type
3140     // has a builder that we can use.
3141     NamedTypeConstraint &result = op.getResult(i);
3142     std::optional<StringRef> builder = result.constraint.getBuilderCall();
3143     if (!builder || result.isVariableLength()) {
3144       return emitErrorAndNote(
3145           loc,
3146           "type of result #" + Twine(i) + ", named '" + result.name +
3147               "', is not buildable and a buildable type cannot be inferred",
3148           "suggest adding a type constraint to the operation or adding a "
3149           "'type($" +
3150               result.name + ")' directive to the " + "custom assembly format");
3151     }
3152     // Note in the format that this result uses the custom builder.
3153     auto it = buildableTypes.insert({*builder, buildableTypes.size()});
3154     fmt.resultTypes[i].setBuilderIdx(it.first->second);
3155   }
3156   return success();
3157 }
3158 
3159 LogicalResult OpFormatParser::verifySuccessors(SMLoc loc) {
3160   // Check that all of the successors are within the format.
3161   if (hasAllSuccessors)
3162     return success();
3163 
3164   for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
3165     const NamedSuccessor &successor = op.getSuccessor(i);
3166     if (!seenSuccessors.count(&successor)) {
3167       return emitErrorAndNote(loc,
3168                               "successor #" + Twine(i) + ", named '" +
3169                                   successor.name + "', not found",
3170                               "suggest adding a '$" + successor.name +
3171                                   "' directive to the custom assembly format");
3172     }
3173   }
3174   return success();
3175 }
3176 
3177 LogicalResult
3178 OpFormatParser::verifyOIListElements(SMLoc loc,
3179                                      ArrayRef<FormatElement *> elements) {
3180   // Check that all of the successors are within the format.
3181   SmallVector<StringRef> prohibitedLiterals;
3182   for (FormatElement *it : elements) {
3183     if (auto *oilist = dyn_cast<OIListElement>(it)) {
3184       if (!prohibitedLiterals.empty()) {
3185         // We just saw an oilist element in last iteration. Literals should not
3186         // match.
3187         for (LiteralElement *literal : oilist->getLiteralElements()) {
3188           if (find(prohibitedLiterals, literal->getSpelling()) !=
3189               prohibitedLiterals.end()) {
3190             return emitError(
3191                 loc, "format ambiguity because " + literal->getSpelling() +
3192                          " is used in two adjacent oilist elements.");
3193           }
3194         }
3195       }
3196       for (LiteralElement *literal : oilist->getLiteralElements())
3197         prohibitedLiterals.push_back(literal->getSpelling());
3198     } else if (auto *literal = dyn_cast<LiteralElement>(it)) {
3199       if (find(prohibitedLiterals, literal->getSpelling()) !=
3200           prohibitedLiterals.end()) {
3201         return emitError(
3202             loc,
3203             "format ambiguity because " + literal->getSpelling() +
3204                 " is used both in oilist element and the adjacent literal.");
3205       }
3206       prohibitedLiterals.clear();
3207     } else {
3208       prohibitedLiterals.clear();
3209     }
3210   }
3211   return success();
3212 }
3213 
3214 void OpFormatParser::handleAllTypesMatchConstraint(
3215     ArrayRef<StringRef> values,
3216     StringMap<TypeResolutionInstance> &variableTyResolver) {
3217   for (unsigned i = 0, e = values.size(); i != e; ++i) {
3218     // Check to see if this value matches a resolved operand or result type.
3219     ConstArgument arg = findSeenArg(values[i]);
3220     if (!arg)
3221       continue;
3222 
3223     // Mark this value as the type resolver for the other variables.
3224     for (unsigned j = 0; j != i; ++j)
3225       variableTyResolver[values[j]] = {arg, std::nullopt};
3226     for (unsigned j = i + 1; j != e; ++j)
3227       variableTyResolver[values[j]] = {arg, std::nullopt};
3228   }
3229 }
3230 
3231 void OpFormatParser::handleSameTypesConstraint(
3232     StringMap<TypeResolutionInstance> &variableTyResolver,
3233     bool includeResults) {
3234   const NamedTypeConstraint *resolver = nullptr;
3235   int resolvedIt = -1;
3236 
3237   // Check to see if there is an operand or result to use for the resolution.
3238   if ((resolvedIt = seenOperandTypes.find_first()) != -1)
3239     resolver = &op.getOperand(resolvedIt);
3240   else if (includeResults && (resolvedIt = seenResultTypes.find_first()) != -1)
3241     resolver = &op.getResult(resolvedIt);
3242   else
3243     return;
3244 
3245   // Set the resolvers for each operand and result.
3246   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
3247     if (!seenOperandTypes.test(i))
3248       variableTyResolver[op.getOperand(i).name] = {resolver, std::nullopt};
3249   if (includeResults) {
3250     for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
3251       if (!seenResultTypes.test(i))
3252         variableTyResolver[op.getResultName(i)] = {resolver, std::nullopt};
3253   }
3254 }
3255 
3256 void OpFormatParser::handleTypesMatchConstraint(
3257     StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def) {
3258   StringRef lhsName = def.getValueAsString("lhs");
3259   StringRef rhsName = def.getValueAsString("rhs");
3260   StringRef transformer = def.getValueAsString("transformer");
3261   if (ConstArgument arg = findSeenArg(lhsName))
3262     variableTyResolver[rhsName] = {arg, transformer};
3263 }
3264 
3265 ConstArgument OpFormatParser::findSeenArg(StringRef name) {
3266   if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
3267     return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
3268   if (const NamedTypeConstraint *arg = findArg(op.getResults(), name))
3269     return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
3270   if (const NamedAttribute *attr = findArg(op.getAttributes(), name))
3271     return seenAttrs.count(attr) ? attr : nullptr;
3272   return nullptr;
3273 }
3274 
/// Resolve the format variable `$name` against what is registered on the op
/// and create the matching variable element. The lookup order is: attributes,
/// properties, operands, regions, results, successors. The binding rules for
/// the parse context `ctx` are enforced along the way:
///   - in a `ref` context the variable must already be bound,
///   - in binding contexts it must not be bound twice,
///   - some kinds are rejected outright in some contexts (e.g. attributes and
///     properties inside `type`, results outside `type`).
/// An error is emitted if `name` matches nothing on the op.
FailureOr<FormatElement *>
OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
  // Check that the parsed argument is something actually registered on the op.
  // Attributes
  if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
    if (ctx == TypeDirectiveContext)
      return emitError(
          loc, "attributes cannot be used as children to a `type` directive");
    if (ctx == RefDirectiveContext) {
      if (!seenAttrs.count(attr))
        return emitError(loc, "attribute '" + name +
                                  "' must be bound before it is referenced");
    } else if (!seenAttrs.insert(attr)) {
      // Any non-ref, non-type context binds the attribute.
      return emitError(loc, "attribute '" + name + "' is already bound");
    }

    return create<AttributeVariable>(attr);
  }

  // Properties: same binding rules as attributes.
  if (const NamedProperty *property = findArg(op.getProperties(), name)) {
    if (ctx == TypeDirectiveContext)
      return emitError(
          loc, "properties cannot be used as children to a `type` directive");
    if (ctx == RefDirectiveContext) {
      if (!seenProperties.count(property))
        return emitError(loc, "property '" + name +
                                  "' must be bound before it is referenced");
    } else {
      if (!seenProperties.insert(property))
        return emitError(loc, "property '" + name + "' is already bound");
    }

    return create<PropertyVariable>(property);
  }

  // Operands
  // Top-level and custom-directive uses bind the operand; this conflicts with
  // an earlier `operands` directive. Other contexts (e.g. inside `type`) do
  // not bind anything here.
  if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) {
    if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
      if (fmt.allOperands || !seenOperands.insert(operand).second)
        return emitError(loc, "operand '" + name + "' is already bound");
    } else if (ctx == RefDirectiveContext && !seenOperands.count(operand)) {
      return emitError(loc, "operand '" + name +
                                "' must be bound before it is referenced");
    }
    return create<OperandVariable>(operand);
  }
  // Regions
  // NOTE(review): unlike operands, a `ref` of an already-bound region does not
  // reach `create` below. It falls into the final `else` and reports "regions
  // can only be used at the top level". Confirm that rejecting `ref($region)`
  // is intended; if so, a more specific message may be warranted.
  if (const NamedRegion *region = findArg(op.getRegions(), name)) {
    if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
      if (hasAllRegions || !seenRegions.insert(region).second)
        return emitError(loc, "region '" + name + "' is already bound");
    } else if (ctx == RefDirectiveContext && !seenRegions.count(region)) {
      return emitError(loc, "region '" + name +
                                "' must be bound before it is referenced");
    } else {
      return emitError(loc, "regions can only be used at the top level");
    }
    return create<RegionVariable>(region);
  }
  // Results.
  // Results are only meaningful as the operand of a `type` directive.
  if (const auto *result = findArg(op.getResults(), name)) {
    if (ctx != TypeDirectiveContext)
      return emitError(loc, "result variables can can only be used as a child "
                            "to a 'type' directive");
    return create<ResultVariable>(result);
  }
  // Successors.
  // NOTE(review): same structure as regions above; a `ref` of a bound
  // successor ends in the "top level" error branch.
  if (const auto *successor = findArg(op.getSuccessors(), name)) {
    if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
      if (hasAllSuccessors || !seenSuccessors.insert(successor).second)
        return emitError(loc, "successor '" + name + "' is already bound");
    } else if (ctx == RefDirectiveContext && !seenSuccessors.count(successor)) {
      return emitError(loc, "successor '" + name +
                                "' must be bound before it is referenced");
    } else {
      return emitError(loc, "successors can only be used at the top level");
    }

    return create<SuccessorVariable>(successor);
  }
  return emitError(loc, "expected variable to refer to an argument, region, "
                        "result, or successor");
}
3358 
3359 FailureOr<FormatElement *>
3360 OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
3361                                    Context ctx) {
3362   switch (kind) {
3363   case FormatToken::kw_prop_dict:
3364     return parsePropDictDirective(loc, ctx);
3365   case FormatToken::kw_attr_dict:
3366     return parseAttrDictDirective(loc, ctx,
3367                                   /*withKeyword=*/false);
3368   case FormatToken::kw_attr_dict_w_keyword:
3369     return parseAttrDictDirective(loc, ctx,
3370                                   /*withKeyword=*/true);
3371   case FormatToken::kw_functional_type:
3372     return parseFunctionalTypeDirective(loc, ctx);
3373   case FormatToken::kw_operands:
3374     return parseOperandsDirective(loc, ctx);
3375   case FormatToken::kw_regions:
3376     return parseRegionsDirective(loc, ctx);
3377   case FormatToken::kw_results:
3378     return parseResultsDirective(loc, ctx);
3379   case FormatToken::kw_successors:
3380     return parseSuccessorsDirective(loc, ctx);
3381   case FormatToken::kw_type:
3382     return parseTypeDirective(loc, ctx);
3383   case FormatToken::kw_oilist:
3384     return parseOIListDirective(loc, ctx);
3385 
3386   default:
3387     return emitError(loc, "unsupported directive kind");
3388   }
3389 }
3390 
3391 FailureOr<FormatElement *>
3392 OpFormatParser::parseAttrDictDirective(SMLoc loc, Context context,
3393                                        bool withKeyword) {
3394   if (context == TypeDirectiveContext)
3395     return emitError(loc, "'attr-dict' directive can only be used as a "
3396                           "top-level directive");
3397 
3398   if (context == RefDirectiveContext) {
3399     if (!hasAttrDict)
3400       return emitError(loc, "'ref' of 'attr-dict' is not bound by a prior "
3401                             "'attr-dict' directive");
3402 
3403     // Otherwise, this is a top-level context.
3404   } else {
3405     if (hasAttrDict)
3406       return emitError(loc, "'attr-dict' directive has already been seen");
3407     hasAttrDict = true;
3408   }
3409 
3410   return create<AttrDictDirective>(withKeyword);
3411 }
3412 
3413 FailureOr<FormatElement *>
3414 OpFormatParser::parsePropDictDirective(SMLoc loc, Context context) {
3415   if (context == TypeDirectiveContext)
3416     return emitError(loc, "'prop-dict' directive can only be used as a "
3417                           "top-level directive");
3418 
3419   if (context == RefDirectiveContext)
3420     llvm::report_fatal_error("'ref' of 'prop-dict' unsupported");
3421   // Otherwise, this is a top-level context.
3422 
3423   if (hasPropDict)
3424     return emitError(loc, "'prop-dict' directive has already been seen");
3425   hasPropDict = true;
3426 
3427   return create<PropDictDirective>();
3428 }
3429 
3430 LogicalResult OpFormatParser::verifyCustomDirectiveArguments(
3431     SMLoc loc, ArrayRef<FormatElement *> arguments) {
3432   for (FormatElement *argument : arguments) {
3433     if (!isa<AttrDictDirective, PropDictDirective, AttributeVariable,
3434              OperandVariable, PropertyVariable, RefDirective, RegionVariable,
3435              SuccessorVariable, StringElement, TypeDirective>(argument)) {
3436       // TODO: FormatElement should have location info attached.
3437       return emitError(loc, "only variables and types may be used as "
3438                             "parameters to a custom directive");
3439     }
3440     if (auto *type = dyn_cast<TypeDirective>(argument)) {
3441       if (!isa<OperandVariable, ResultVariable>(type->getArg())) {
3442         return emitError(loc, "type directives within a custom directive may "
3443                               "only refer to variables");
3444       }
3445     }
3446   }
3447   return success();
3448 }
3449 
3450 FailureOr<FormatElement *>
3451 OpFormatParser::parseFunctionalTypeDirective(SMLoc loc, Context context) {
3452   if (context != TopLevelContext)
3453     return emitError(
3454         loc, "'functional-type' is only valid as a top-level directive");
3455 
3456   // Parse the main operand.
3457   FailureOr<FormatElement *> inputs, results;
3458   if (failed(parseToken(FormatToken::l_paren,
3459                         "expected '(' before argument list")) ||
3460       failed(inputs = parseTypeDirectiveOperand(loc)) ||
3461       failed(parseToken(FormatToken::comma,
3462                         "expected ',' after inputs argument")) ||
3463       failed(results = parseTypeDirectiveOperand(loc)) ||
3464       failed(
3465           parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3466     return failure();
3467   return create<FunctionalTypeDirective>(*inputs, *results);
3468 }
3469 
3470 FailureOr<FormatElement *>
3471 OpFormatParser::parseOperandsDirective(SMLoc loc, Context context) {
3472   if (context == RefDirectiveContext) {
3473     if (!fmt.allOperands)
3474       return emitError(loc, "'ref' of 'operands' is not bound by a prior "
3475                             "'operands' directive");
3476 
3477   } else if (context == TopLevelContext || context == CustomDirectiveContext) {
3478     if (fmt.allOperands || !seenOperands.empty())
3479       return emitError(loc, "'operands' directive creates overlap in format");
3480     fmt.allOperands = true;
3481   }
3482   return create<OperandsDirective>();
3483 }
3484 
3485 FailureOr<FormatElement *>
3486 OpFormatParser::parseRegionsDirective(SMLoc loc, Context context) {
3487   if (context == TypeDirectiveContext)
3488     return emitError(loc, "'regions' is only valid as a top-level directive");
3489   if (context == RefDirectiveContext) {
3490     if (!hasAllRegions)
3491       return emitError(loc, "'ref' of 'regions' is not bound by a prior "
3492                             "'regions' directive");
3493 
3494     // Otherwise, this is a TopLevel directive.
3495   } else {
3496     if (hasAllRegions || !seenRegions.empty())
3497       return emitError(loc, "'regions' directive creates overlap in format");
3498     hasAllRegions = true;
3499   }
3500   return create<RegionsDirective>();
3501 }
3502 
3503 FailureOr<FormatElement *>
3504 OpFormatParser::parseResultsDirective(SMLoc loc, Context context) {
3505   if (context != TypeDirectiveContext)
3506     return emitError(loc, "'results' directive can can only be used as a child "
3507                           "to a 'type' directive");
3508   return create<ResultsDirective>();
3509 }
3510 
3511 FailureOr<FormatElement *>
3512 OpFormatParser::parseSuccessorsDirective(SMLoc loc, Context context) {
3513   if (context == TypeDirectiveContext)
3514     return emitError(loc,
3515                      "'successors' is only valid as a top-level directive");
3516   if (context == RefDirectiveContext) {
3517     if (!hasAllSuccessors)
3518       return emitError(loc, "'ref' of 'successors' is not bound by a prior "
3519                             "'successors' directive");
3520 
3521     // Otherwise, this is a TopLevel directive.
3522   } else {
3523     if (hasAllSuccessors || !seenSuccessors.empty())
3524       return emitError(loc, "'successors' directive creates overlap in format");
3525     hasAllSuccessors = true;
3526   }
3527   return create<SuccessorsDirective>();
3528 }
3529 
3530 FailureOr<FormatElement *>
3531 OpFormatParser::parseOIListDirective(SMLoc loc, Context context) {
3532   if (failed(parseToken(FormatToken::l_paren,
3533                         "expected '(' before oilist argument list")))
3534     return failure();
3535   std::vector<FormatElement *> literalElements;
3536   std::vector<std::vector<FormatElement *>> parsingElements;
3537   do {
3538     FailureOr<FormatElement *> lelement = parseLiteral(context);
3539     if (failed(lelement))
3540       return failure();
3541     literalElements.push_back(*lelement);
3542     parsingElements.emplace_back();
3543     std::vector<FormatElement *> &currParsingElements = parsingElements.back();
3544     while (peekToken().getKind() != FormatToken::pipe &&
3545            peekToken().getKind() != FormatToken::r_paren) {
3546       FailureOr<FormatElement *> pelement = parseElement(context);
3547       if (failed(pelement) ||
3548           failed(verifyOIListParsingElement(*pelement, loc)))
3549         return failure();
3550       currParsingElements.push_back(*pelement);
3551     }
3552     if (peekToken().getKind() == FormatToken::pipe) {
3553       consumeToken();
3554       continue;
3555     }
3556     if (peekToken().getKind() == FormatToken::r_paren) {
3557       consumeToken();
3558       break;
3559     }
3560   } while (true);
3561 
3562   return create<OIListElement>(std::move(literalElements),
3563                                std::move(parsingElements));
3564 }
3565 
3566 LogicalResult OpFormatParser::verifyOIListParsingElement(FormatElement *element,
3567                                                          SMLoc loc) {
3568   SmallVector<VariableElement *> vars;
3569   collect(element, vars);
3570   for (VariableElement *elem : vars) {
3571     LogicalResult res =
3572         TypeSwitch<FormatElement *, LogicalResult>(elem)
3573             // Only optional attributes can be within an oilist parsing group.
3574             .Case([&](AttributeVariable *attrEle) {
3575               if (!attrEle->getVar()->attr.isOptional() &&
3576                   !attrEle->getVar()->attr.hasDefaultValue())
3577                 return emitError(loc, "only optional attributes can be used in "
3578                                       "an oilist parsing group");
3579               return success();
3580             })
3581             // Only optional properties can be within an oilist parsing group.
3582             .Case([&](PropertyVariable *propEle) {
3583               if (!propEle->getVar()->prop.hasDefaultValue())
3584                 return emitError(
3585                     loc,
3586                     "only default-valued or optional properties can be used in "
3587                     "an olist parsing group");
3588               return success();
3589             })
3590             // Only optional-like(i.e. variadic) operands can be within an
3591             // oilist parsing group.
3592             .Case([&](OperandVariable *ele) {
3593               if (!ele->getVar()->isVariableLength())
3594                 return emitError(loc, "only variable length operands can be "
3595                                       "used within an oilist parsing group");
3596               return success();
3597             })
3598             // Only optional-like(i.e. variadic) results can be within an oilist
3599             // parsing group.
3600             .Case([&](ResultVariable *ele) {
3601               if (!ele->getVar()->isVariableLength())
3602                 return emitError(loc, "only variable length results can be "
3603                                       "used within an oilist parsing group");
3604               return success();
3605             })
3606             .Case([&](RegionVariable *) { return success(); })
3607             .Default([&](FormatElement *) {
3608               return emitError(loc,
3609                                "only literals, types, and variables can be "
3610                                "used within an oilist group");
3611             });
3612     if (failed(res))
3613       return failure();
3614   }
3615   return success();
3616 }
3617 
3618 FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc,
3619                                                               Context context) {
3620   if (context == TypeDirectiveContext)
3621     return emitError(loc, "'type' cannot be used as a child of another `type`");
3622 
3623   bool isRefChild = context == RefDirectiveContext;
3624   FailureOr<FormatElement *> operand;
3625   if (failed(parseToken(FormatToken::l_paren,
3626                         "expected '(' before argument list")) ||
3627       failed(operand = parseTypeDirectiveOperand(loc, isRefChild)) ||
3628       failed(
3629           parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3630     return failure();
3631 
3632   return create<TypeDirective>(*operand);
3633 }
3634 
3635 LogicalResult OpFormatParser::markQualified(SMLoc loc, FormatElement *element) {
3636   return TypeSwitch<FormatElement *, LogicalResult>(element)
3637       .Case<AttributeVariable, TypeDirective>([](auto *element) {
3638         element->setShouldBeQualified();
3639         return success();
3640       })
3641       .Default([&](auto *element) {
3642         return this->emitError(
3643             loc,
3644             "'qualified' directive expects an attribute or a `type` directive");
3645       });
3646 }
3647 
3648 FailureOr<FormatElement *>
3649 OpFormatParser::parseTypeDirectiveOperand(SMLoc loc, bool isRefChild) {
3650   FailureOr<FormatElement *> result = parseElement(TypeDirectiveContext);
3651   if (failed(result))
3652     return failure();
3653 
3654   FormatElement *element = *result;
3655   if (isa<LiteralElement>(element))
3656     return emitError(
3657         loc, "'type' directive operand expects variable or directive operand");
3658 
3659   if (auto *var = dyn_cast<OperandVariable>(element)) {
3660     unsigned opIdx = var->getVar() - op.operand_begin();
3661     if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
3662       return emitError(loc, "'type' of '" + var->getVar()->name +
3663                                 "' is already bound");
3664     if (isRefChild && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
3665       return emitError(loc, "'ref' of 'type($" + var->getVar()->name +
3666                                 ")' is not bound by a prior 'type' directive");
3667     seenOperandTypes.set(opIdx);
3668   } else if (auto *var = dyn_cast<ResultVariable>(element)) {
3669     unsigned resIdx = var->getVar() - op.result_begin();
3670     if (!isRefChild && (fmt.allResultTypes || seenResultTypes.test(resIdx)))
3671       return emitError(loc, "'type' of '" + var->getVar()->name +
3672                                 "' is already bound");
3673     if (isRefChild && !(fmt.allResultTypes || seenResultTypes.test(resIdx)))
3674       return emitError(loc, "'ref' of 'type($" + var->getVar()->name +
3675                                 ")' is not bound by a prior 'type' directive");
3676     seenResultTypes.set(resIdx);
3677   } else if (isa<OperandsDirective>(&*element)) {
3678     if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.any()))
3679       return emitError(loc, "'operands' 'type' is already bound");
3680     if (isRefChild && !fmt.allOperandTypes)
3681       return emitError(loc, "'ref' of 'type(operands)' is not bound by a prior "
3682                             "'type' directive");
3683     fmt.allOperandTypes = true;
3684   } else if (isa<ResultsDirective>(&*element)) {
3685     if (!isRefChild && (fmt.allResultTypes || seenResultTypes.any()))
3686       return emitError(loc, "'results' 'type' is already bound");
3687     if (isRefChild && !fmt.allResultTypes)
3688       return emitError(loc, "'ref' of 'type(results)' is not bound by a prior "
3689                             "'type' directive");
3690     fmt.allResultTypes = true;
3691   } else {
3692     return emitError(loc, "invalid argument to 'type' directive");
3693   }
3694   return element;
3695 }
3696 
3697 LogicalResult OpFormatParser::verifyOptionalGroupElements(
3698     SMLoc loc, ArrayRef<FormatElement *> elements, FormatElement *anchor) {
3699   for (FormatElement *element : elements) {
3700     if (failed(verifyOptionalGroupElement(loc, element, element == anchor)))
3701       return failure();
3702   }
3703   return success();
3704 }
3705 
3706 LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc,
3707                                                          FormatElement *element,
3708                                                          bool isAnchor) {
3709   return TypeSwitch<FormatElement *, LogicalResult>(element)
3710       // All attributes can be within the optional group, but only optional
3711       // attributes can be the anchor.
3712       .Case([&](AttributeVariable *attrEle) {
3713         Attribute attr = attrEle->getVar()->attr;
3714         if (isAnchor && !(attr.isOptional() || attr.hasDefaultValue()))
3715           return emitError(loc, "only optional or default-valued attributes "
3716                                 "can be used to anchor an optional group");
3717         return success();
3718       })
3719       // All properties can be within the optional group, but only optional
3720       // properties can be the anchor.
3721       .Case([&](PropertyVariable *propEle) {
3722         Property prop = propEle->getVar()->prop;
3723         if (isAnchor && !(prop.hasDefaultValue() && prop.hasOptionalParser()))
3724           return emitError(loc, "only properties with default values "
3725                                 "that can be optionally parsed "
3726                                 "can be used to anchor an optional group");
3727         return success();
3728       })
3729       // Only optional-like(i.e. variadic) operands can be within an optional
3730       // group.
3731       .Case([&](OperandVariable *ele) {
3732         if (!ele->getVar()->isVariableLength())
3733           return emitError(loc, "only variable length operands can be used "
3734                                 "within an optional group");
3735         return success();
3736       })
3737       // Only optional-like(i.e. variadic) results can be within an optional
3738       // group.
3739       .Case([&](ResultVariable *ele) {
3740         if (!ele->getVar()->isVariableLength())
3741           return emitError(loc, "only variable length results can be used "
3742                                 "within an optional group");
3743         return success();
3744       })
3745       .Case([&](RegionVariable *) {
3746         // TODO: When ODS has proper support for marking "optional" regions, add
3747         // a check here.
3748         return success();
3749       })
3750       .Case([&](TypeDirective *ele) {
3751         return verifyOptionalGroupElement(loc, ele->getArg(),
3752                                           /*isAnchor=*/false);
3753       })
3754       .Case([&](FunctionalTypeDirective *ele) {
3755         if (failed(verifyOptionalGroupElement(loc, ele->getInputs(),
3756                                               /*isAnchor=*/false)))
3757           return failure();
3758         return verifyOptionalGroupElement(loc, ele->getResults(),
3759                                           /*isAnchor=*/false);
3760       })
3761       .Case([&](CustomDirective *ele) {
3762         if (!isAnchor)
3763           return success();
3764         // Verify each child as being valid in an optional group. They are all
3765         // potential anchors if the custom directive was marked as one.
3766         for (FormatElement *child : ele->getArguments()) {
3767           if (isa<RefDirective>(child))
3768             continue;
3769           if (failed(verifyOptionalGroupElement(loc, child, /*isAnchor=*/true)))
3770             return failure();
3771         }
3772         return success();
3773       })
3774       // Literals, whitespace, and custom directives may be used, but they can't
3775       // anchor the group.
3776       .Case<LiteralElement, WhitespaceElement, OptionalElement>(
3777           [&](FormatElement *) {
3778             if (isAnchor)
3779               return emitError(loc, "only variables and types can be used "
3780                                     "to anchor an optional group");
3781             return success();
3782           })
3783       .Default([&](FormatElement *) {
3784         return emitError(loc, "only literals, types, and variables can be "
3785                               "used within an optional group");
3786       });
3787 }
3788 
3789 //===----------------------------------------------------------------------===//
3790 // Interface
3791 //===----------------------------------------------------------------------===//
3792 
3793 void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass,
3794                                     bool hasProperties) {
3795   // TODO: Operator doesn't expose all necessary functionality via
3796   // the const interface.
3797   Operator &op = const_cast<Operator &>(constOp);
3798   if (!op.hasAssemblyFormat())
3799     return;
3800 
3801   // Parse the format description.
3802   llvm::SourceMgr mgr;
3803   mgr.AddNewSourceBuffer(
3804       llvm::MemoryBuffer::getMemBuffer(op.getAssemblyFormat()), SMLoc());
3805   OperationFormat format(op, hasProperties);
3806   OpFormatParser parser(mgr, format, op);
3807   FailureOr<std::vector<FormatElement *>> elements = parser.parse();
3808   if (failed(elements)) {
3809     // Exit the process if format errors are treated as fatal.
3810     if (formatErrorIsFatal) {
3811       // Invoke the interrupt handlers to run the file cleanup handlers.
3812       llvm::sys::RunInterruptHandlers();
3813       std::exit(1);
3814     }
3815     return;
3816   }
3817   format.elements = std::move(*elements);
3818 
3819   // Generate the printer and parser based on the parsed format.
3820   format.genParser(op, opClass);
3821   format.genPrinter(op, opClass);
3822 }
3823