xref: /llvm-project/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp (revision 7359a6b7996f92e6659418d3d2e5b57c44d65e37)
1 //===- AttrOrTypeFormatGen.cpp - MLIR attribute and type format generator -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "AttrOrTypeFormatGen.h"
10 #include "FormatGen.h"
11 #include "mlir/Support/LLVM.h"
12 #include "mlir/TableGen/AttrOrTypeDef.h"
13 #include "mlir/TableGen/Format.h"
14 #include "mlir/TableGen/GenInfo.h"
15 #include "llvm/ADT/BitVector.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/ADT/StringSwitch.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/MemoryBuffer.h"
20 #include "llvm/Support/SaveAndRestore.h"
21 #include "llvm/Support/SourceMgr.h"
22 #include "llvm/TableGen/Error.h"
23 #include "llvm/TableGen/TableGenBackend.h"
24 
25 using namespace mlir;
26 using namespace mlir::tblgen;
27 
28 using llvm::formatv;
29 
30 //===----------------------------------------------------------------------===//
31 // Element
32 //===----------------------------------------------------------------------===//
33 
34 namespace {
35 /// This class represents an instance of a variable element. A variable refers
36 /// to an attribute or type parameter.
37 class ParameterElement
38     : public VariableElementBase<VariableElement::Parameter> {
39 public:
40   ParameterElement(AttrOrTypeParameter param) : param(param) {}
41 
42   /// Get the parameter in the element.
43   const AttrOrTypeParameter &getParam() const { return param; }
44 
45   /// Indicate if this variable is printed "qualified" (that is it is
46   /// prefixed with the `#dialect.mnemonic`).
47   bool shouldBeQualified() { return shouldBeQualifiedFlag; }
48   void setShouldBeQualified(bool qualified = true) {
49     shouldBeQualifiedFlag = qualified;
50   }
51 
52   /// Returns true if the element contains an optional parameter.
53   bool isOptional() const { return param.isOptional(); }
54 
55   /// Returns the name of the parameter.
56   StringRef getName() const { return param.getName(); }
57 
58   /// Return the code to check whether the parameter is present.
59   auto genIsPresent(FmtContext &ctx, const Twine &self) const {
60     assert(isOptional() && "cannot guard on a mandatory parameter");
61     std::string valueStr = tgfmt(*param.getDefaultValue(), &ctx).str();
62     ctx.addSubst("_lhs", self).addSubst("_rhs", valueStr);
63     return tgfmt(getParam().getComparator(), &ctx);
64   }
65 
66   /// Generate the code to check whether the parameter should be printed.
67   MethodBody &genPrintGuard(FmtContext &ctx, MethodBody &os) const {
68     assert(isOptional() && "cannot guard on a mandatory parameter");
69     std::string self = param.getAccessorName() + "()";
70     return os << "!(" << genIsPresent(ctx, self) << ")";
71   }
72 
73 private:
74   bool shouldBeQualifiedFlag = false;
75   AttrOrTypeParameter param;
76 };
77 
78 /// Shorthand functions that can be used with ranged-based conditions.
79 static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); }
80 static bool paramNotOptional(ParameterElement *el) { return !el->isOptional(); }
81 
82 /// Base class for a directive that contains references to multiple variables.
83 template <DirectiveElement::Kind DirectiveKind>
84 class ParamsDirectiveBase : public DirectiveElementBase<DirectiveKind> {
85 public:
86   using Base = ParamsDirectiveBase<DirectiveKind>;
87 
88   ParamsDirectiveBase(std::vector<ParameterElement *> &&params)
89       : params(std::move(params)) {}
90 
91   /// Get the parameters contained in this directive.
92   ArrayRef<ParameterElement *> getParams() const { return params; }
93 
94   /// Get the number of parameters.
95   unsigned getNumParams() const { return params.size(); }
96 
97   /// Take all of the parameters from this directive.
98   std::vector<ParameterElement *> takeParams() { return std::move(params); }
99 
100   /// Returns true if there are optional parameters present.
101   bool hasOptionalParams() const {
102     return llvm::any_of(getParams(), paramIsOptional);
103   }
104 
105 private:
106   /// The parameters captured by this directive.
107   std::vector<ParameterElement *> params;
108 };
109 
110 /// This class represents a `params` directive that refers to all parameters
111 /// of an attribute or type. When used as a top-level directive, it generates
112 /// a format of the form:
113 ///
114 ///   (param-value (`,` param-value)*)?
115 ///
116 /// When used as an argument to another directive that accepts variables,
117 /// `params` can be used in place of manually listing all parameters of an
118 /// attribute or type.
119 class ParamsDirective : public ParamsDirectiveBase<DirectiveElement::Params> {
120 public:
121   using Base::Base;
122 };
123 
124 /// This class represents a `struct` directive that generates a struct format
125 /// of the form:
126 ///
127 ///   `{` param-name `=` param-value (`,` param-name `=` param-value)* `}`
128 ///
129 class StructDirective : public ParamsDirectiveBase<DirectiveElement::Struct> {
130 public:
131   using Base::Base;
132 };
133 
134 } // namespace
135 
136 //===----------------------------------------------------------------------===//
137 // Format Strings
138 //===----------------------------------------------------------------------===//
139 
/// Default parser for attribute or type parameters.
///
/// $0: The C++ storage type of the parameter.
static const char *const defaultParameterParser =
    "::mlir::FieldParser<$0>::parse($_parser)";

/// Default printer for attribute or type parameters. It prints the value with
/// its dialect and mnemonic stripped.
static const char *const defaultParameterPrinter =
    "$_printer.printStrippedAttrOrType($_self)";

/// Qualified printer for attribute or type parameters: it does not elide
/// dialect and mnemonic.
static const char *const qualifiedParameterPrinter = "$_printer << $_self";

/// Opening of the code that emits an error at the parser's current location.
/// It takes no positional arguments. The generated code appends the message
/// and the closing `)`, for example `{2}"message")` in `variableParser`.
static const char *const parserErrorStr =
    "$_parser.emitError($_parser.getCurrentLocation(), ";

/// Code format to parse a variable. Separate by lines because variable parsers
/// may be generated inside other directives, which requires indentation.
///
/// {0}: The parameter name.
/// {1}: The parse code for the parameter.
/// {2}: Code template for printing an error.
/// {3}: Name of the attribute or type.
/// {4}: C++ class of the parameter.
/// {5}: Optional code to preload the dialect for this variable.
static const char *const variableParser = R"(
// Parse variable '{0}'{5}
_result_{0} = {1};
if (::mlir::failed(_result_{0})) {{
  {2}"failed to parse {3} parameter '{0}' which is to be a `{4}`");
  return {{};
}
)";
175 
176 //===----------------------------------------------------------------------===//
177 // DefFormat
178 //===----------------------------------------------------------------------===//
179 
180 namespace {
181 class DefFormat {
182 public:
183   DefFormat(const AttrOrTypeDef &def, std::vector<FormatElement *> &&elements)
184       : def(def), elements(std::move(elements)) {}
185 
186   /// Generate the attribute or type parser.
187   void genParser(MethodBody &os);
188   /// Generate the attribute or type printer.
189   void genPrinter(MethodBody &os);
190 
191 private:
192   /// Generate the parser code for a specific format element.
193   void genElementParser(FormatElement *el, FmtContext &ctx, MethodBody &os);
194   /// Generate the parser code for a literal.
195   void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os,
196                         bool isOptional = false);
197   /// Generate the parser code for a variable.
198   void genVariableParser(ParameterElement *el, FmtContext &ctx, MethodBody &os);
199   /// Generate the parser code for a `params` directive.
200   void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
201   /// Generate the parser code for a `struct` directive.
202   void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os);
203   /// Generate the parser code for a `custom` directive.
204   void genCustomParser(CustomDirective *el, FmtContext &ctx, MethodBody &os,
205                        bool isOptional = false);
206   /// Generate the parser code for an optional group.
207   void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
208                               MethodBody &os);
209 
210   /// Generate the printer code for a specific format element.
211   void genElementPrinter(FormatElement *el, FmtContext &ctx, MethodBody &os);
212   /// Generate the printer code for a literal.
213   void genLiteralPrinter(StringRef value, FmtContext &ctx, MethodBody &os);
214   /// Generate the printer code for a variable.
215   void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os,
216                           bool skipGuard = false);
217   /// Generate a printer for comma-separated parameters.
218   void genCommaSeparatedPrinter(ArrayRef<ParameterElement *> params,
219                                 FmtContext &ctx, MethodBody &os,
220                                 function_ref<void(ParameterElement *)> extra);
221   /// Generate the printer code for a `params` directive.
222   void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
223   /// Generate the printer code for a `struct` directive.
224   void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os);
225   /// Generate the printer code for a `custom` directive.
226   void genCustomPrinter(CustomDirective *el, FmtContext &ctx, MethodBody &os);
227   /// Generate the printer code for an optional group.
228   void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
229                                MethodBody &os);
230   /// Generate a printer (or space eraser) for a whitespace element.
231   void genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx,
232                             MethodBody &os);
233 
234   /// The ODS definition of the attribute or type whose format is being used to
235   /// generate a parser and printer.
236   const AttrOrTypeDef &def;
237   /// The list of top-level format elements returned by the assembly format
238   /// parser.
239   std::vector<FormatElement *> elements;
240 
241   /// Flags for printing spaces.
242   bool shouldEmitSpace = false;
243   bool lastWasPunctuation = false;
244 };
245 } // namespace
246 
247 //===----------------------------------------------------------------------===//
248 // ParserGen
249 //===----------------------------------------------------------------------===//
250 
251 /// Generate a special-case "parser" for an attribute's self type parameter. The
252 /// self type parameter has special handling in the assembly format in that it
253 /// is derived from the optional trailing colon type after the attribute.
254 static void genAttrSelfTypeParser(MethodBody &os, const FmtContext &ctx,
255                                   const AttributeSelfTypeParameter &param) {
256   // "Parser" for an attribute self type parameter that checks the
257   // optionally-parsed trailing colon type.
258   //
259   // $0: The C++ storage class of the type parameter.
260   // $1: The self type parameter name.
261   const char *const selfTypeParser = R"(
262 if ($_type) {
263   if (auto reqType = ::llvm::dyn_cast<$0>($_type)) {
264     _result_$1 = reqType;
265   } else {
266     $_parser.emitError($_loc, "invalid kind of type specified");
267     return {};
268   }
269 })";
270 
271   // If the attribute self type parameter is required, emit code that emits an
272   // error if the trailing type was not parsed.
273   const char *const selfTypeRequired = R"( else {
274   $_parser.emitError($_loc, "expected a trailing type");
275   return {};
276 })";
277 
278   os << tgfmt(selfTypeParser, &ctx, param.getCppStorageType(), param.getName());
279   if (!param.isOptional())
280     os << tgfmt(selfTypeRequired, &ctx);
281   os << "\n";
282 }
283 
284 void DefFormat::genParser(MethodBody &os) {
285   FmtContext ctx;
286   ctx.addSubst("_parser", "odsParser");
287   ctx.addSubst("_ctxt", "odsParser.getContext()");
288   ctx.withBuilder("odsBuilder");
289   if (isa<AttrDef>(def))
290     ctx.addSubst("_type", "odsType");
291   os.indent();
292   os << "::mlir::Builder odsBuilder(odsParser.getContext());\n";
293 
294   // Store the initial location of the parser.
295   ctx.addSubst("_loc", "odsLoc");
296   os << tgfmt("::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n"
297               "(void) $_loc;\n",
298               &ctx);
299 
300   // Declare variables to store all of the parameters. Allocated parameters
301   // such as `ArrayRef` and `StringRef` must provide a `storageType`. Store
302   // FailureOr<T> to defer type construction for parameters that are parsed in
303   // a loop (parsers return FailureOr anyways).
304   ArrayRef<AttrOrTypeParameter> params = def.getParameters();
305   for (const AttrOrTypeParameter &param : params) {
306     os << formatv("::mlir::FailureOr<{0}> _result_{1};\n",
307                   param.getCppStorageType(), param.getName());
308     if (auto *selfTypeParam = dyn_cast<AttributeSelfTypeParameter>(&param))
309       genAttrSelfTypeParser(os, ctx, *selfTypeParam);
310   }
311 
312   // Generate call to each parameter parser.
313   for (FormatElement *el : elements)
314     genElementParser(el, ctx, os);
315 
316   // Emit an assert for each mandatory parameter. Triggering an assert means
317   // the generated parser is incorrect (i.e. there is a bug in this code).
318   for (const AttrOrTypeParameter &param : params) {
319     if (param.isOptional())
320       continue;
321     os << formatv("assert(::mlir::succeeded(_result_{0}));\n", param.getName());
322   }
323 
324   // Generate call to the attribute or type builder. Use the checked getter
325   // if one was generated.
326   if (def.genVerifyDecl() || def.genVerifyInvariantsImpl()) {
327     os << tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
328                 &ctx, def.getCppClassName());
329   } else {
330     os << tgfmt("return $0::get($_parser.getContext()", &ctx,
331                 def.getCppClassName());
332   }
333   for (const AttrOrTypeParameter &param : params) {
334     os << ",\n    ";
335     std::string paramSelfStr;
336     llvm::raw_string_ostream selfOs(paramSelfStr);
337     if (std::optional<StringRef> defaultValue = param.getDefaultValue()) {
338       selfOs << formatv("(_result_{0}.value_or(", param.getName())
339              << tgfmt(*defaultValue, &ctx) << "))";
340     } else {
341       selfOs << formatv("(*_result_{0})", param.getName());
342     }
343     ctx.addSubst(param.getName(), selfOs.str());
344     os << param.getCppType() << "("
345        << tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str()))
346        << ")";
347   }
348   os << ");";
349 }
350 
351 void DefFormat::genElementParser(FormatElement *el, FmtContext &ctx,
352                                  MethodBody &os) {
353   if (auto *literal = dyn_cast<LiteralElement>(el))
354     return genLiteralParser(literal->getSpelling(), ctx, os);
355   if (auto *var = dyn_cast<ParameterElement>(el))
356     return genVariableParser(var, ctx, os);
357   if (auto *params = dyn_cast<ParamsDirective>(el))
358     return genParamsParser(params, ctx, os);
359   if (auto *strct = dyn_cast<StructDirective>(el))
360     return genStructParser(strct, ctx, os);
361   if (auto *custom = dyn_cast<CustomDirective>(el))
362     return genCustomParser(custom, ctx, os);
363   if (auto *optional = dyn_cast<OptionalElement>(el))
364     return genOptionalGroupParser(optional, ctx, os);
365   if (isa<WhitespaceElement>(el))
366     return;
367 
368   llvm_unreachable("unknown format element");
369 }
370 
371 void DefFormat::genLiteralParser(StringRef value, FmtContext &ctx,
372                                  MethodBody &os, bool isOptional) {
373   os << "// Parse literal '" << value << "'\n";
374   os << tgfmt("if ($_parser.parse", &ctx);
375   if (isOptional)
376     os << "Optional";
377   if (value.front() == '_' || isalpha(value.front())) {
378     os << "Keyword(\"" << value << "\")";
379   } else {
380     os << StringSwitch<StringRef>(value)
381               .Case("->", "Arrow")
382               .Case(":", "Colon")
383               .Case(",", "Comma")
384               .Case("=", "Equal")
385               .Case("<", "Less")
386               .Case(">", "Greater")
387               .Case("{", "LBrace")
388               .Case("}", "RBrace")
389               .Case("(", "LParen")
390               .Case(")", "RParen")
391               .Case("[", "LSquare")
392               .Case("]", "RSquare")
393               .Case("?", "Question")
394               .Case("+", "Plus")
395               .Case("*", "Star")
396               .Case("...", "Ellipsis")
397        << "()";
398   }
399   if (isOptional) {
400     // Leave the `if` unclosed to guard optional groups.
401     return;
402   }
403   // Parser will emit an error
404   os << ") return {};\n";
405 }
406 
407 void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
408                                   MethodBody &os) {
409   // Check for a custom parser. Use the default attribute parser otherwise.
410   const AttrOrTypeParameter &param = el->getParam();
411   auto customParser = param.getParser();
412   auto parser =
413       customParser ? *customParser : StringRef(defaultParameterParser);
414 
415   // If the variable points to a dialect specific entity (type of attribute),
416   // we force load the dialect now before trying to parse it.
417   std::string dialectLoading;
418   if (auto *defInit = dyn_cast<llvm::DefInit>(param.getDef())) {
419     auto *dialectValue = defInit->getDef()->getValue("dialect");
420     if (dialectValue) {
421       if (auto *dialectInit =
422               dyn_cast<llvm::DefInit>(dialectValue->getValue())) {
423         Dialect dialect(dialectInit->getDef());
424         auto cppNamespace = dialect.getCppNamespace();
425         std::string name = dialect.getCppClassName();
426         if (name != "BuiltinDialect" || cppNamespace != "::mlir") {
427           dialectLoading = ("\nodsParser.getContext()->getOrLoadDialect<" +
428                             cppNamespace + "::" + name + ">();")
429                                .str();
430         }
431       }
432     }
433   }
434   os << formatv(variableParser, param.getName(),
435                 tgfmt(parser, &ctx, param.getCppStorageType()),
436                 tgfmt(parserErrorStr, &ctx), def.getName(), param.getCppType(),
437                 dialectLoading);
438 }
439 
440 void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
441                                 MethodBody &os) {
442   os << "// Parse parameter list\n";
443 
444   // If there are optional parameters, we need to switch to `parseOptionalComma`
445   // if there are no more required parameters after a certain point.
446   bool hasOptional = el->hasOptionalParams();
447   if (hasOptional) {
448     // Wrap everything in a do-while so that we can `break`.
449     os << "do {\n";
450     os.indent();
451   }
452 
453   ArrayRef<ParameterElement *> params = el->getParams();
454   using IteratorT = ParameterElement *const *;
455   IteratorT it = params.begin();
456 
457   // Find the last required parameter. Commas become optional aftewards.
458   // Note: IteratorT's copy assignment is deleted.
459   ParameterElement *lastReq = nullptr;
460   for (ParameterElement *param : params)
461     if (!param->isOptional())
462       lastReq = param;
463   IteratorT lastReqIt = lastReq ? llvm::find(params, lastReq) : params.begin();
464 
465   auto eachFn = [&](ParameterElement *el) { genVariableParser(el, ctx, os); };
466   auto betweenFn = [&](IteratorT it) {
467     ParameterElement *el = *std::prev(it);
468     // Parse a comma if the last optional parameter had a value.
469     if (el->isOptional()) {
470       os << formatv("if (::mlir::succeeded(_result_{0}) && !({1})) {{\n",
471                     el->getName(),
472                     el->genIsPresent(ctx, "(*_result_" + el->getName() + ")"));
473       os.indent();
474     }
475     if (it <= lastReqIt) {
476       genLiteralParser(",", ctx, os);
477     } else {
478       genLiteralParser(",", ctx, os, /*isOptional=*/true);
479       os << ") break;\n";
480     }
481     if (el->isOptional())
482       os.unindent() << "}\n";
483   };
484 
485   // llvm::interleave
486   if (it != params.end()) {
487     eachFn(*it++);
488     for (IteratorT e = params.end(); it != e; ++it) {
489       betweenFn(it);
490       eachFn(*it);
491     }
492   }
493 
494   if (hasOptional)
495     os.unindent() << "} while(false);\n";
496 }
497 
498 void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
499                                 MethodBody &os) {
500   // Loop declaration for struct parser with only required parameters.
501   //
502   // $0: Number of expected parameters.
503   const char *const loopHeader = R"(
504   for (unsigned odsStructIndex = 0; odsStructIndex < $0; ++odsStructIndex) {
505 )";
506 
507   // Loop body start for struct parser.
508   const char *const loopStart = R"(
509     ::llvm::StringRef _paramKey;
510     if ($_parser.parseKeyword(&_paramKey)) {
511       $_parser.emitError($_parser.getCurrentLocation(),
512                          "expected a parameter name in struct");
513       return {};
514     }
515     if (!_loop_body(_paramKey)) return {};
516 )";
517 
518   // Struct parser loop end. Check for duplicate or unknown struct parameters.
519   //
520   // {0}: Code template for printing an error.
521   const char *const loopEnd = R"({{
522   {0}"duplicate or unknown struct parameter name: ") << _paramKey;
523   return {{};
524 }
525 )";
526 
527   // Struct parser loop terminator. Parse a comma except on the last element.
528   //
529   // {0}: Number of elements in the struct.
530   const char *const loopTerminator = R"(
531   if ((odsStructIndex != {0} - 1) && odsParser.parseComma())
532     return {{};
533 }
534 )";
535 
536   // Check that a mandatory parameter was parse.
537   //
538   // {0}: Name of the parameter.
539   const char *const checkParam = R"(
540     if (!_seen_{0}) {
541       {1}"struct is missing required parameter: ") << "{0}";
542       return {{};
543     }
544 )";
545 
546   // First iteration of the loop parsing an optional struct.
547   const char *const optionalStructFirst = R"(
548   ::llvm::StringRef _paramKey;
549   if (!$_parser.parseOptionalKeyword(&_paramKey)) {
550     if (!_loop_body(_paramKey)) return {};
551     while (!$_parser.parseOptionalComma()) {
552 )";
553 
554   os << "// Parse parameter struct\n";
555 
556   // Declare a "seen" variable for each key.
557   for (ParameterElement *param : el->getParams())
558     os << formatv("bool _seen_{0} = false;\n", param->getName());
559 
560   // Generate the body of the parsing loop inside a lambda.
561   os << "{\n";
562   os.indent()
563       << "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n";
564   genLiteralParser("=", ctx, os.indent());
565   for (ParameterElement *param : el->getParams()) {
566     os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n"
567                   "  _seen_{0} = true;\n",
568                   param->getName());
569     genVariableParser(param, ctx, os.indent());
570     os.unindent() << "} else ";
571     // Print the check for duplicate or unknown parameter.
572   }
573   os.getStream().printReindented(strfmt(loopEnd, tgfmt(parserErrorStr, &ctx)));
574   os << "return true;\n";
575   os.unindent() << "};\n";
576 
577   // Generate the parsing loop. If optional parameters are present, then the
578   // parse loop is guarded by commas.
579   unsigned numOptional = llvm::count_if(el->getParams(), paramIsOptional);
580   if (numOptional) {
581     // If the struct itself is optional, pull out the first iteration.
582     if (numOptional == el->getNumParams()) {
583       os.getStream().printReindented(tgfmt(optionalStructFirst, &ctx).str());
584       os.indent();
585     } else {
586       os << "do {\n";
587     }
588   } else {
589     os.getStream().printReindented(
590         tgfmt(loopHeader, &ctx, el->getNumParams()).str());
591   }
592   os.indent();
593   os.getStream().printReindented(tgfmt(loopStart, &ctx).str());
594   os.unindent();
595 
596   // Print the loop terminator. For optional parameters, we have to check that
597   // all mandatory parameters have been parsed.
598   // The whole struct is optional if all its parameters are optional.
599   if (numOptional) {
600     if (numOptional == el->getNumParams()) {
601       os << "}\n";
602       os.unindent() << "}\n";
603     } else {
604       os << tgfmt("} while(!$_parser.parseOptionalComma());\n", &ctx);
605       for (ParameterElement *param : el->getParams()) {
606         if (param->isOptional())
607           continue;
608         os.getStream().printReindented(
609             strfmt(checkParam, param->getName(), tgfmt(parserErrorStr, &ctx)));
610       }
611     }
612   } else {
613     // Because the loop loops N times and each non-failing iteration sets 1 of
614     // N flags, successfully exiting the loop means that all parameters have
615     // been seen. `parseOptionalComma` would cause issues with any formats that
616     // use "struct(...) `,`" beacuse structs aren't sounded by braces.
617     os.getStream().printReindented(strfmt(loopTerminator, el->getNumParams()));
618   }
619   os.unindent() << "}\n";
620 }
621 
622 void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
623                                 MethodBody &os, bool isOptional) {
624   os << "{\n";
625   os.indent();
626 
627   // Bound variables are passed directly to the parser as `FailureOr<T> &`.
628   // Referenced variables are passed as `T`. The custom parser fails if it
629   // returns failure or if any of the required parameters failed.
630   os << tgfmt("auto odsCustomLoc = $_parser.getCurrentLocation();\n", &ctx);
631   os << "(void)odsCustomLoc;\n";
632   os << tgfmt("auto odsCustomResult = parse$0($_parser", &ctx, el->getName());
633   os.indent();
634   for (FormatElement *arg : el->getArguments()) {
635     os << ",\n";
636     if (auto *param = dyn_cast<ParameterElement>(arg))
637       os << "::mlir::detail::unwrapForCustomParse(_result_" << param->getName()
638          << ")";
639     else if (auto *ref = dyn_cast<RefDirective>(arg))
640       os << "*_result_" << cast<ParameterElement>(ref->getArg())->getName();
641     else
642       os << tgfmt(cast<StringElement>(arg)->getValue(), &ctx);
643   }
644   os.unindent() << ");\n";
645   if (isOptional) {
646     os << "if (!odsCustomResult.has_value()) return {};\n";
647     os << "if (::mlir::failed(*odsCustomResult)) return ::mlir::failure();\n";
648   } else {
649     os << "if (::mlir::failed(odsCustomResult)) return {};\n";
650   }
651   for (FormatElement *arg : el->getArguments()) {
652     if (auto *param = dyn_cast<ParameterElement>(arg)) {
653       if (param->isOptional())
654         continue;
655       os << formatv("if (::mlir::failed(_result_{0})) {{\n", param->getName());
656       os.indent() << tgfmt("$_parser.emitError(odsCustomLoc, ", &ctx)
657                   << "\"custom parser failed to parse parameter '"
658                   << param->getName() << "'\");\n";
659       os << "return " << (isOptional ? "::mlir::failure()" : "{}") << ";\n";
660       os.unindent() << "}\n";
661     }
662   }
663 
664   os.unindent() << "}\n";
665 }
666 
667 void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
668                                        MethodBody &os) {
669   ArrayRef<FormatElement *> thenElements =
670       el->getThenElements(/*parseable=*/true);
671 
672   FormatElement *first = thenElements.front();
673   const auto guardOn = [&](auto params) {
674     os << "if (!(";
675     llvm::interleave(
676         params, os,
677         [&](ParameterElement *el) {
678           os << formatv("(::mlir::succeeded(_result_{0}) && *_result_{0})",
679                         el->getName());
680         },
681         " || ");
682     os << ")) {\n";
683   };
684   if (auto *literal = dyn_cast<LiteralElement>(first)) {
685     genLiteralParser(literal->getSpelling(), ctx, os, /*isOptional=*/true);
686     os << ") {\n";
687   } else if (auto *param = dyn_cast<ParameterElement>(first)) {
688     genVariableParser(param, ctx, os);
689     guardOn(llvm::ArrayRef(param));
690   } else if (auto *params = dyn_cast<ParamsDirective>(first)) {
691     genParamsParser(params, ctx, os);
692     guardOn(params->getParams());
693   } else if (auto *custom = dyn_cast<CustomDirective>(first)) {
694     os << "if (auto result = [&]() -> ::mlir::OptionalParseResult {\n";
695     os.indent();
696     genCustomParser(custom, ctx, os, /*isOptional=*/true);
697     os << "return ::mlir::success();\n";
698     os.unindent();
699     os << "}(); result.has_value() && ::mlir::failed(*result)) {\n";
700     os.indent();
701     os << "return {};\n";
702     os.unindent();
703     os << "} else if (result.has_value()) {\n";
704   } else {
705     auto *strct = cast<StructDirective>(first);
706     genStructParser(strct, ctx, os);
707     guardOn(params->getParams());
708   }
709   os.indent();
710 
711   // Generate the parsers for the rest of the thenElements.
712   for (FormatElement *element : el->getElseElements(/*parseable=*/true))
713     genElementParser(element, ctx, os);
714   os.unindent() << "} else {\n";
715   os.indent();
716   for (FormatElement *element : thenElements.drop_front())
717     genElementParser(element, ctx, os);
718   os.unindent() << "}\n";
719 }
720 
721 //===----------------------------------------------------------------------===//
722 // PrinterGen
723 //===----------------------------------------------------------------------===//
724 
725 void DefFormat::genPrinter(MethodBody &os) {
726   FmtContext ctx;
727   ctx.addSubst("_printer", "odsPrinter");
728   ctx.addSubst("_ctxt", "getContext()");
729   ctx.withBuilder("odsBuilder");
730   os.indent();
731   os << "::mlir::Builder odsBuilder(getContext());\n";
732 
733   // Generate printers.
734   shouldEmitSpace = true;
735   lastWasPunctuation = false;
736   for (FormatElement *el : elements)
737     genElementPrinter(el, ctx, os);
738 }
739 
740 void DefFormat::genElementPrinter(FormatElement *el, FmtContext &ctx,
741                                   MethodBody &os) {
742   if (auto *literal = dyn_cast<LiteralElement>(el))
743     return genLiteralPrinter(literal->getSpelling(), ctx, os);
744   if (auto *params = dyn_cast<ParamsDirective>(el))
745     return genParamsPrinter(params, ctx, os);
746   if (auto *strct = dyn_cast<StructDirective>(el))
747     return genStructPrinter(strct, ctx, os);
748   if (auto *custom = dyn_cast<CustomDirective>(el))
749     return genCustomPrinter(custom, ctx, os);
750   if (auto *var = dyn_cast<ParameterElement>(el))
751     return genVariablePrinter(var, ctx, os);
752   if (auto *optional = dyn_cast<OptionalElement>(el))
753     return genOptionalGroupPrinter(optional, ctx, os);
754   if (auto *whitespace = dyn_cast<WhitespaceElement>(el))
755     return genWhitespacePrinter(whitespace, ctx, os);
756 
757   llvm::PrintFatalError("unsupported format element");
758 }
759 
760 void DefFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
761                                   MethodBody &os) {
762   // Don't insert a space before certain punctuation.
763   bool needSpace =
764       shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation);
765   os << tgfmt("$_printer$0 << \"$1\";\n", &ctx, needSpace ? " << ' '" : "",
766               value);
767 
768   // Update the flags.
769   shouldEmitSpace =
770       value.size() != 1 || !StringRef("<({[").contains(value.front());
771   lastWasPunctuation = value.front() != '_' && !isalpha(value.front());
772 }
773 
774 void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx,
775                                    MethodBody &os, bool skipGuard) {
776   const AttrOrTypeParameter &param = el->getParam();
777   ctx.withSelf(param.getAccessorName() + "()");
778 
779   // Guard the printer on the presence of optional parameters and that they
780   // aren't equal to their default values (if they have one).
781   if (el->isOptional() && !skipGuard) {
782     el->genPrintGuard(ctx, os << "if (") << ") {\n";
783     os.indent();
784   }
785 
786   // Insert a space before the next parameter, if necessary.
787   if (shouldEmitSpace || !lastWasPunctuation)
788     os << tgfmt("$_printer << ' ';\n", &ctx);
789   shouldEmitSpace = true;
790   lastWasPunctuation = false;
791 
792   if (el->shouldBeQualified())
793     os << tgfmt(qualifiedParameterPrinter, &ctx) << ";\n";
794   else if (auto printer = param.getPrinter())
795     os << tgfmt(*printer, &ctx) << ";\n";
796   else
797     os << tgfmt(defaultParameterPrinter, &ctx) << ";\n";
798 
799   if (el->isOptional() && !skipGuard)
800     os.unindent() << "}\n";
801 }
802 
803 /// Generate code to guard printing on the presence of any optional parameters.
804 template <typename ParameterRange>
805 static void guardOnAny(FmtContext &ctx, MethodBody &os, ParameterRange &&params,
806                        bool inverted = false) {
807   os << "if (";
808   if (inverted)
809     os << "!(";
810   llvm::interleave(
811       params, os,
812       [&](ParameterElement *param) { param->genPrintGuard(ctx, os); }, " || ");
813   if (inverted)
814     os << ")";
815   os << ") {\n";
816   os.indent();
817 }
818 
819 void DefFormat::genCommaSeparatedPrinter(
820     ArrayRef<ParameterElement *> params, FmtContext &ctx, MethodBody &os,
821     function_ref<void(ParameterElement *)> extra) {
822   // Emit a space if necessary, but only if the struct is present.
823   if (shouldEmitSpace || !lastWasPunctuation) {
824     bool allOptional = llvm::all_of(params, paramIsOptional);
825     if (allOptional)
826       guardOnAny(ctx, os, params);
827     os << tgfmt("$_printer << ' ';\n", &ctx);
828     if (allOptional)
829       os.unindent() << "}\n";
830   }
831 
832   // The first printed element does not need to emit a comma.
833   os << "{\n";
834   os.indent() << "bool _firstPrinted = true;\n";
835   for (ParameterElement *param : params) {
836     if (param->isOptional()) {
837       param->genPrintGuard(ctx, os << "if (") << ") {\n";
838       os.indent();
839     }
840     os << tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx);
841     os << "_firstPrinted = false;\n";
842     extra(param);
843     shouldEmitSpace = false;
844     lastWasPunctuation = true;
845     genVariablePrinter(param, ctx, os);
846     if (param->isOptional())
847       os.unindent() << "}\n";
848   }
849   os.unindent() << "}\n";
850 }
851 
852 void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
853                                  MethodBody &os) {
854   genCommaSeparatedPrinter(llvm::to_vector(el->getParams()), ctx, os,
855                            [&](ParameterElement *param) {});
856 }
857 
858 void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
859                                  MethodBody &os) {
860   genCommaSeparatedPrinter(
861       llvm::to_vector(el->getParams()), ctx, os, [&](ParameterElement *param) {
862         os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName());
863       });
864 }
865 
866 void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
867                                  MethodBody &os) {
868   // Insert a space before the custom directive, if necessary.
869   if (shouldEmitSpace || !lastWasPunctuation)
870     os << tgfmt("$_printer << ' ';\n", &ctx);
871   shouldEmitSpace = true;
872   lastWasPunctuation = false;
873 
874   os << tgfmt("print$0($_printer", &ctx, el->getName());
875   os.indent();
876   for (FormatElement *arg : el->getArguments()) {
877     os << ",\n";
878     if (auto *param = dyn_cast<ParameterElement>(arg)) {
879       os << param->getParam().getAccessorName() << "()";
880     } else if (auto *ref = dyn_cast<RefDirective>(arg)) {
881       os << cast<ParameterElement>(ref->getArg())->getParam().getAccessorName()
882          << "()";
883     } else {
884       os << tgfmt(cast<StringElement>(arg)->getValue(), &ctx);
885     }
886   }
887   os.unindent() << ");\n";
888 }
889 
890 void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
891                                         MethodBody &os) {
892   FormatElement *anchor = el->getAnchor();
893   if (auto *param = dyn_cast<ParameterElement>(anchor)) {
894     guardOnAny(ctx, os, llvm::ArrayRef(param), el->isInverted());
895   } else if (auto *params = dyn_cast<ParamsDirective>(anchor)) {
896     guardOnAny(ctx, os, params->getParams(), el->isInverted());
897   } else if (auto *strct = dyn_cast<StructDirective>(anchor)) {
898     guardOnAny(ctx, os, strct->getParams(), el->isInverted());
899   } else {
900     auto *custom = cast<CustomDirective>(anchor);
901     guardOnAny(ctx, os,
902                llvm::make_filter_range(
903                    llvm::map_range(custom->getArguments(),
904                                    [](FormatElement *el) {
905                                      return dyn_cast<ParameterElement>(el);
906                                    }),
907                    [](ParameterElement *param) { return !!param; }),
908                el->isInverted());
909   }
910   // Generate the printer for the contained elements.
911   {
912     llvm::SaveAndRestore shouldEmitSpaceFlag(shouldEmitSpace);
913     llvm::SaveAndRestore lastWasPunctuationFlag(lastWasPunctuation);
914     for (FormatElement *element : el->getThenElements())
915       genElementPrinter(element, ctx, os);
916   }
917   os.unindent() << "} else {\n";
918   os.indent();
919   for (FormatElement *element : el->getElseElements())
920     genElementPrinter(element, ctx, os);
921   os.unindent() << "}\n";
922 }
923 
924 void DefFormat::genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx,
925                                      MethodBody &os) {
926   if (el->getValue() == "\\n") {
927     // FIXME: The newline should be `printer.printNewLine()`, i.e., handled by
928     // the printer.
929     os << tgfmt("$_printer << '\\n';\n", &ctx);
930   } else if (!el->getValue().empty()) {
931     os << tgfmt("$_printer << \"$0\";\n", &ctx, el->getValue());
932   } else {
933     lastWasPunctuation = true;
934   }
935   shouldEmitSpace = false;
936 }
937 
938 //===----------------------------------------------------------------------===//
939 // DefFormatParser
940 //===----------------------------------------------------------------------===//
941 
942 namespace {
943 class DefFormatParser : public FormatParser {
944 public:
945   DefFormatParser(llvm::SourceMgr &mgr, const AttrOrTypeDef &def)
946       : FormatParser(mgr, def.getLoc()[0]), def(def),
947         seenParams(def.getNumParameters()) {}
948 
949   /// Parse the attribute or type format and create the format elements.
950   FailureOr<DefFormat> parse();
951 
952 protected:
953   /// Verify the parsed elements.
954   LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override;
955   /// Verify the elements of a custom directive.
956   LogicalResult
957   verifyCustomDirectiveArguments(SMLoc loc,
958                                  ArrayRef<FormatElement *> arguments) override;
959   /// Verify the elements of an optional group.
960   LogicalResult verifyOptionalGroupElements(SMLoc loc,
961                                             ArrayRef<FormatElement *> elements,
962                                             FormatElement *anchor) override;
963 
964   LogicalResult markQualified(SMLoc loc, FormatElement *element) override;
965 
966   /// Parse an attribute or type variable.
967   FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
968                                                Context ctx) override;
969   /// Parse an attribute or type format directive.
970   FailureOr<FormatElement *>
971   parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override;
972 
973 private:
974   /// Parse a `params` directive.
975   FailureOr<FormatElement *> parseParamsDirective(SMLoc loc, Context ctx);
976   /// Parse a `struct` directive.
977   FailureOr<FormatElement *> parseStructDirective(SMLoc loc, Context ctx);
978 
979   /// Attribute or type tablegen def.
980   const AttrOrTypeDef &def;
981 
982   /// Seen attribute or type parameters.
983   BitVector seenParams;
984 };
985 } // namespace
986 
987 LogicalResult DefFormatParser::verify(SMLoc loc,
988                                       ArrayRef<FormatElement *> elements) {
989   // Check that all parameters are referenced in the format.
990   for (auto [index, param] : llvm::enumerate(def.getParameters())) {
991     if (param.isOptional())
992       continue;
993     if (!seenParams.test(index)) {
994       if (isa<AttributeSelfTypeParameter>(param))
995         continue;
996       return emitError(loc, "format is missing reference to parameter: " +
997                                 param.getName());
998     }
999     if (isa<AttributeSelfTypeParameter>(param)) {
1000       return emitError(loc,
1001                        "unexpected self type parameter in assembly format");
1002     }
1003   }
1004   if (elements.empty())
1005     return success();
1006   // A `struct` directive that contains optional parameters cannot be followed
1007   // by a comma literal, which is ambiguous.
1008   for (auto it : llvm::zip(elements.drop_back(), elements.drop_front())) {
1009     auto *structEl = dyn_cast<StructDirective>(std::get<0>(it));
1010     auto *literalEl = dyn_cast<LiteralElement>(std::get<1>(it));
1011     if (!structEl || !literalEl)
1012       continue;
1013     if (literalEl->getSpelling() == "," && structEl->hasOptionalParams()) {
1014       return emitError(loc, "`struct` directive with optional parameters "
1015                             "cannot be followed by a comma literal");
1016     }
1017   }
1018   return success();
1019 }
1020 
1021 LogicalResult DefFormatParser::verifyCustomDirectiveArguments(
1022     SMLoc loc, ArrayRef<FormatElement *> arguments) {
1023   // Arguments are fully verified by the parser context.
1024   return success();
1025 }
1026 
1027 LogicalResult
1028 DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
1029                                              ArrayRef<FormatElement *> elements,
1030                                              FormatElement *anchor) {
1031   // `params` and `struct` directives are allowed only if all the contained
1032   // parameters are optional.
1033   for (FormatElement *el : elements) {
1034     if (auto *param = dyn_cast<ParameterElement>(el)) {
1035       if (!param->isOptional()) {
1036         return emitError(loc,
1037                          "parameters in an optional group must be optional");
1038       }
1039     } else if (auto *params = dyn_cast<ParamsDirective>(el)) {
1040       if (llvm::any_of(params->getParams(), paramNotOptional)) {
1041         return emitError(loc, "`params` directive allowed in optional group "
1042                               "only if all parameters are optional");
1043       }
1044     } else if (auto *strct = dyn_cast<StructDirective>(el)) {
1045       if (llvm::any_of(strct->getParams(), paramNotOptional)) {
1046         return emitError(loc, "`struct` is only allowed in an optional group "
1047                               "if all captured parameters are optional");
1048       }
1049     } else if (auto *custom = dyn_cast<CustomDirective>(el)) {
1050       for (FormatElement *el : custom->getArguments()) {
1051         // If the custom argument is a variable, then it must be optional.
1052         if (auto *param = dyn_cast<ParameterElement>(el))
1053           if (!param->isOptional())
1054             return emitError(loc,
1055                              "`custom` is only allowed in an optional group if "
1056                              "all captured parameters are optional");
1057       }
1058     }
1059   }
1060   // The anchor must be a parameter or one of the aforementioned directives.
1061   if (anchor) {
1062     if (!isa<ParameterElement, ParamsDirective, StructDirective,
1063              CustomDirective>(anchor)) {
1064       return emitError(
1065           loc, "optional group anchor must be a parameter or directive");
1066     }
1067     // If the anchor is a custom directive, make sure at least one of its
1068     // arguments is a bound parameter.
1069     if (auto *custom = dyn_cast<CustomDirective>(anchor)) {
1070       const auto *bound =
1071           llvm::find_if(custom->getArguments(), [](FormatElement *el) {
1072             return isa<ParameterElement>(el);
1073           });
1074       if (bound == custom->getArguments().end())
1075         return emitError(loc, "`custom` directive with no bound parameters "
1076                               "cannot be used as optional group anchor");
1077     }
1078   }
1079   return success();
1080 }
1081 
1082 LogicalResult DefFormatParser::markQualified(SMLoc loc,
1083                                              FormatElement *element) {
1084   if (!isa<ParameterElement>(element))
1085     return emitError(loc, "`qualified` argument list expected a variable");
1086   cast<ParameterElement>(element)->setShouldBeQualified();
1087   return success();
1088 }
1089 
1090 FailureOr<DefFormat> DefFormatParser::parse() {
1091   FailureOr<std::vector<FormatElement *>> elements = FormatParser::parse();
1092   if (failed(elements))
1093     return failure();
1094   return DefFormat(def, std::move(*elements));
1095 }
1096 
1097 FailureOr<FormatElement *>
1098 DefFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
1099   // Lookup the parameter.
1100   ArrayRef<AttrOrTypeParameter> params = def.getParameters();
1101   auto *it = llvm::find_if(
1102       params, [&](auto &param) { return param.getName() == name; });
1103 
1104   // Check that the parameter reference is valid.
1105   if (it == params.end()) {
1106     return emitError(loc,
1107                      def.getName() + " has no parameter named '" + name + "'");
1108   }
1109   auto idx = std::distance(params.begin(), it);
1110 
1111   if (ctx != RefDirectiveContext) {
1112     // Check that the variable has not already been bound.
1113     if (seenParams.test(idx))
1114       return emitError(loc, "duplicate parameter '" + name + "'");
1115     seenParams.set(idx);
1116 
1117     // Otherwise, to be referenced, a variable must have been bound.
1118   } else if (!seenParams.test(idx) && !isa<AttributeSelfTypeParameter>(*it)) {
1119     return emitError(loc, "parameter '" + name +
1120                               "' must be bound before it is referenced");
1121   }
1122 
1123   return create<ParameterElement>(*it);
1124 }
1125 
1126 FailureOr<FormatElement *>
1127 DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
1128                                     Context ctx) {
1129 
1130   switch (kind) {
1131   case FormatToken::kw_qualified:
1132     return parseQualifiedDirective(loc, ctx);
1133   case FormatToken::kw_params:
1134     return parseParamsDirective(loc, ctx);
1135   case FormatToken::kw_struct:
1136     return parseStructDirective(loc, ctx);
1137   default:
1138     return emitError(loc, "unsupported directive kind");
1139   }
1140 }
1141 
1142 FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc,
1143                                                                  Context ctx) {
1144   // It doesn't make sense to allow references to all parameters in a custom
1145   // directive because parameters are the only things that can be bound.
1146   if (ctx != TopLevelContext && ctx != StructDirectiveContext) {
1147     return emitError(loc, "`params` can only be used at the top-level context "
1148                           "or within a `struct` directive");
1149   }
1150 
1151   // Collect all of the attribute's or type's parameters and ensure that none of
1152   // the parameters have already been captured.
1153   std::vector<ParameterElement *> vars;
1154   for (const auto &it : llvm::enumerate(def.getParameters())) {
1155     if (seenParams.test(it.index())) {
1156       return emitError(loc, "`params` captures duplicate parameter: " +
1157                                 it.value().getName());
1158     }
1159     // Self-type parameters are handled separately from the rest of the
1160     // parameters.
1161     if (isa<AttributeSelfTypeParameter>(it.value()))
1162       continue;
1163     seenParams.set(it.index());
1164     vars.push_back(create<ParameterElement>(it.value()));
1165   }
1166   return create<ParamsDirective>(std::move(vars));
1167 }
1168 
1169 FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
1170                                                                  Context ctx) {
1171   if (ctx != TopLevelContext)
1172     return emitError(loc, "`struct` can only be used at the top-level context");
1173 
1174   if (failed(parseToken(FormatToken::l_paren,
1175                         "expected '(' before `struct` argument list")))
1176     return failure();
1177 
1178   // Parse variables captured by `struct`.
1179   std::vector<ParameterElement *> vars;
1180 
1181   // Parse first captured parameter or a `params` directive.
1182   FailureOr<FormatElement *> var = parseElement(StructDirectiveContext);
1183   if (failed(var) || !isa<VariableElement, ParamsDirective>(*var)) {
1184     return emitError(loc,
1185                      "`struct` argument list expected a variable or directive");
1186   }
1187   if (isa<VariableElement>(*var)) {
1188     // Parse any other parameters.
1189     vars.push_back(cast<ParameterElement>(*var));
1190     while (peekToken().is(FormatToken::comma)) {
1191       consumeToken();
1192       var = parseElement(StructDirectiveContext);
1193       if (failed(var) || !isa<VariableElement>(*var))
1194         return emitError(loc, "expected a variable in `struct` argument list");
1195       vars.push_back(cast<ParameterElement>(*var));
1196     }
1197   } else {
1198     // `struct(params)` captures all parameters in the attribute or type.
1199     vars = cast<ParamsDirective>(*var)->takeParams();
1200   }
1201 
1202   if (failed(parseToken(FormatToken::r_paren,
1203                         "expected ')' at the end of an argument list")))
1204     return failure();
1205 
1206   return create<StructDirective>(std::move(vars));
1207 }
1208 
1209 //===----------------------------------------------------------------------===//
1210 // Interface
1211 //===----------------------------------------------------------------------===//
1212 
1213 void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def,
1214                                             MethodBody &parser,
1215                                             MethodBody &printer) {
1216   llvm::SourceMgr mgr;
1217   mgr.AddNewSourceBuffer(
1218       llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()), SMLoc());
1219 
1220   // Parse the custom assembly format>
1221   DefFormatParser fmtParser(mgr, def);
1222   FailureOr<DefFormat> format = fmtParser.parse();
1223   if (failed(format)) {
1224     if (formatErrorIsFatal)
1225       PrintFatalError(def.getLoc(), "failed to parse assembly format");
1226     return;
1227   }
1228 
1229   // Generate the parser and printer.
1230   format->genParser(parser);
1231   format->genPrinter(printer);
1232 }
1233