1 //===- AttrOrTypeFormatGen.cpp - MLIR attribute and type format generator -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "AttrOrTypeFormatGen.h" 10 #include "FormatGen.h" 11 #include "mlir/Support/LLVM.h" 12 #include "mlir/TableGen/AttrOrTypeDef.h" 13 #include "mlir/TableGen/Format.h" 14 #include "mlir/TableGen/GenInfo.h" 15 #include "llvm/ADT/BitVector.h" 16 #include "llvm/ADT/StringExtras.h" 17 #include "llvm/ADT/StringSwitch.h" 18 #include "llvm/ADT/TypeSwitch.h" 19 #include "llvm/Support/MemoryBuffer.h" 20 #include "llvm/Support/SaveAndRestore.h" 21 #include "llvm/Support/SourceMgr.h" 22 #include "llvm/TableGen/Error.h" 23 #include "llvm/TableGen/TableGenBackend.h" 24 25 using namespace mlir; 26 using namespace mlir::tblgen; 27 28 using llvm::formatv; 29 30 //===----------------------------------------------------------------------===// 31 // Element 32 //===----------------------------------------------------------------------===// 33 34 namespace { 35 /// This class represents an instance of a variable element. A variable refers 36 /// to an attribute or type parameter. 37 class ParameterElement 38 : public VariableElementBase<VariableElement::Parameter> { 39 public: 40 ParameterElement(AttrOrTypeParameter param) : param(param) {} 41 42 /// Get the parameter in the element. 43 const AttrOrTypeParameter &getParam() const { return param; } 44 45 /// Indicate if this variable is printed "qualified" (that is it is 46 /// prefixed with the `#dialect.mnemonic`). 47 bool shouldBeQualified() { return shouldBeQualifiedFlag; } 48 void setShouldBeQualified(bool qualified = true) { 49 shouldBeQualifiedFlag = qualified; 50 } 51 52 /// Returns true if the element contains an optional parameter. 53 bool isOptional() const { return param.isOptional(); } 54 55 /// Returns the name of the parameter. 56 StringRef getName() const { return param.getName(); } 57 58 /// Return the code to check whether the parameter is present. 59 auto genIsPresent(FmtContext &ctx, const Twine &self) const { 60 assert(isOptional() && "cannot guard on a mandatory parameter"); 61 std::string valueStr = tgfmt(*param.getDefaultValue(), &ctx).str(); 62 ctx.addSubst("_lhs", self).addSubst("_rhs", valueStr); 63 return tgfmt(getParam().getComparator(), &ctx); 64 } 65 66 /// Generate the code to check whether the parameter should be printed. 67 MethodBody &genPrintGuard(FmtContext &ctx, MethodBody &os) const { 68 assert(isOptional() && "cannot guard on a mandatory parameter"); 69 std::string self = param.getAccessorName() + "()"; 70 return os << "!(" << genIsPresent(ctx, self) << ")"; 71 } 72 73 private: 74 bool shouldBeQualifiedFlag = false; 75 AttrOrTypeParameter param; 76 }; 77 78 /// Shorthand functions that can be used with ranged-based conditions. 79 static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); } 80 static bool paramNotOptional(ParameterElement *el) { return !el->isOptional(); } 81 82 /// Base class for a directive that contains references to multiple variables. 83 template <DirectiveElement::Kind DirectiveKind> 84 class ParamsDirectiveBase : public DirectiveElementBase<DirectiveKind> { 85 public: 86 using Base = ParamsDirectiveBase<DirectiveKind>; 87 88 ParamsDirectiveBase(std::vector<ParameterElement *> &¶ms) 89 : params(std::move(params)) {} 90 91 /// Get the parameters contained in this directive. 92 ArrayRef<ParameterElement *> getParams() const { return params; } 93 94 /// Get the number of parameters. 95 unsigned getNumParams() const { return params.size(); } 96 97 /// Take all of the parameters from this directive. 98 std::vector<ParameterElement *> takeParams() { return std::move(params); } 99 100 /// Returns true if there are optional parameters present. 101 bool hasOptionalParams() const { 102 return llvm::any_of(getParams(), paramIsOptional); 103 } 104 105 private: 106 /// The parameters captured by this directive. 107 std::vector<ParameterElement *> params; 108 }; 109 110 /// This class represents a `params` directive that refers to all parameters 111 /// of an attribute or type. When used as a top-level directive, it generates 112 /// a format of the form: 113 /// 114 /// (param-value (`,` param-value)*)? 115 /// 116 /// When used as an argument to another directive that accepts variables, 117 /// `params` can be used in place of manually listing all parameters of an 118 /// attribute or type. 119 class ParamsDirective : public ParamsDirectiveBase<DirectiveElement::Params> { 120 public: 121 using Base::Base; 122 }; 123 124 /// This class represents a `struct` directive that generates a struct format 125 /// of the form: 126 /// 127 /// `{` param-name `=` param-value (`,` param-name `=` param-value)* `}` 128 /// 129 class StructDirective : public ParamsDirectiveBase<DirectiveElement::Struct> { 130 public: 131 using Base::Base; 132 }; 133 134 } // namespace 135 136 //===----------------------------------------------------------------------===// 137 // Format Strings 138 //===----------------------------------------------------------------------===// 139 140 /// Default parser for attribute or type parameters. 141 static const char *const defaultParameterParser = 142 "::mlir::FieldParser<$0>::parse($_parser)"; 143 144 /// Default printer for attribute or type parameters. 145 static const char *const defaultParameterPrinter = 146 "$_printer.printStrippedAttrOrType($_self)"; 147 148 /// Qualified printer for attribute or type parameters: it does not elide 149 /// dialect and mnemonic. 150 static const char *const qualifiedParameterPrinter = "$_printer << $_self"; 151 152 /// Print an error when failing to parse an element. 153 /// 154 /// $0: The parameter C++ class name. 155 static const char *const parserErrorStr = 156 "$_parser.emitError($_parser.getCurrentLocation(), "; 157 158 /// Code format to parse a variable. Separate by lines because variable parsers 159 /// may be generated inside other directives, which requires indentation. 160 /// 161 /// {0}: The parameter name. 162 /// {1}: The parse code for the parameter. 163 /// {2}: Code template for printing an error. 164 /// {3}: Name of the attribute or type. 165 /// {4}: C++ class of the parameter. 166 /// {5}: Optional code to preload the dialect for this variable. 167 static const char *const variableParser = R"( 168 // Parse variable '{0}'{5} 169 _result_{0} = {1}; 170 if (::mlir::failed(_result_{0})) {{ 171 {2}"failed to parse {3} parameter '{0}' which is to be a `{4}`"); 172 return {{}; 173 } 174 )"; 175 176 //===----------------------------------------------------------------------===// 177 // DefFormat 178 //===----------------------------------------------------------------------===// 179 180 namespace { 181 class DefFormat { 182 public: 183 DefFormat(const AttrOrTypeDef &def, std::vector<FormatElement *> &&elements) 184 : def(def), elements(std::move(elements)) {} 185 186 /// Generate the attribute or type parser. 187 void genParser(MethodBody &os); 188 /// Generate the attribute or type printer. 189 void genPrinter(MethodBody &os); 190 191 private: 192 /// Generate the parser code for a specific format element. 193 void genElementParser(FormatElement *el, FmtContext &ctx, MethodBody &os); 194 /// Generate the parser code for a literal. 195 void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os, 196 bool isOptional = false); 197 /// Generate the parser code for a variable. 198 void genVariableParser(ParameterElement *el, FmtContext &ctx, MethodBody &os); 199 /// Generate the parser code for a `params` directive. 200 void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os); 201 /// Generate the parser code for a `struct` directive. 202 void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os); 203 /// Generate the parser code for a `custom` directive. 204 void genCustomParser(CustomDirective *el, FmtContext &ctx, MethodBody &os, 205 bool isOptional = false); 206 /// Generate the parser code for an optional group. 207 void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx, 208 MethodBody &os); 209 210 /// Generate the printer code for a specific format element. 211 void genElementPrinter(FormatElement *el, FmtContext &ctx, MethodBody &os); 212 /// Generate the printer code for a literal. 213 void genLiteralPrinter(StringRef value, FmtContext &ctx, MethodBody &os); 214 /// Generate the printer code for a variable. 215 void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os, 216 bool skipGuard = false); 217 /// Generate a printer for comma-separated parameters. 218 void genCommaSeparatedPrinter(ArrayRef<ParameterElement *> params, 219 FmtContext &ctx, MethodBody &os, 220 function_ref<void(ParameterElement *)> extra); 221 /// Generate the printer code for a `params` directive. 222 void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os); 223 /// Generate the printer code for a `struct` directive. 224 void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os); 225 /// Generate the printer code for a `custom` directive. 226 void genCustomPrinter(CustomDirective *el, FmtContext &ctx, MethodBody &os); 227 /// Generate the printer code for an optional group. 228 void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx, 229 MethodBody &os); 230 /// Generate a printer (or space eraser) for a whitespace element. 231 void genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx, 232 MethodBody &os); 233 234 /// The ODS definition of the attribute or type whose format is being used to 235 /// generate a parser and printer. 236 const AttrOrTypeDef &def; 237 /// The list of top-level format elements returned by the assembly format 238 /// parser. 239 std::vector<FormatElement *> elements; 240 241 /// Flags for printing spaces. 242 bool shouldEmitSpace = false; 243 bool lastWasPunctuation = false; 244 }; 245 } // namespace 246 247 //===----------------------------------------------------------------------===// 248 // ParserGen 249 //===----------------------------------------------------------------------===// 250 251 /// Generate a special-case "parser" for an attribute's self type parameter. The 252 /// self type parameter has special handling in the assembly format in that it 253 /// is derived from the optional trailing colon type after the attribute. 254 static void genAttrSelfTypeParser(MethodBody &os, const FmtContext &ctx, 255 const AttributeSelfTypeParameter ¶m) { 256 // "Parser" for an attribute self type parameter that checks the 257 // optionally-parsed trailing colon type. 258 // 259 // $0: The C++ storage class of the type parameter. 260 // $1: The self type parameter name. 261 const char *const selfTypeParser = R"( 262 if ($_type) { 263 if (auto reqType = ::llvm::dyn_cast<$0>($_type)) { 264 _result_$1 = reqType; 265 } else { 266 $_parser.emitError($_loc, "invalid kind of type specified"); 267 return {}; 268 } 269 })"; 270 271 // If the attribute self type parameter is required, emit code that emits an 272 // error if the trailing type was not parsed. 273 const char *const selfTypeRequired = R"( else { 274 $_parser.emitError($_loc, "expected a trailing type"); 275 return {}; 276 })"; 277 278 os << tgfmt(selfTypeParser, &ctx, param.getCppStorageType(), param.getName()); 279 if (!param.isOptional()) 280 os << tgfmt(selfTypeRequired, &ctx); 281 os << "\n"; 282 } 283 284 void DefFormat::genParser(MethodBody &os) { 285 FmtContext ctx; 286 ctx.addSubst("_parser", "odsParser"); 287 ctx.addSubst("_ctxt", "odsParser.getContext()"); 288 ctx.withBuilder("odsBuilder"); 289 if (isa<AttrDef>(def)) 290 ctx.addSubst("_type", "odsType"); 291 os.indent(); 292 os << "::mlir::Builder odsBuilder(odsParser.getContext());\n"; 293 294 // Store the initial location of the parser. 295 ctx.addSubst("_loc", "odsLoc"); 296 os << tgfmt("::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n" 297 "(void) $_loc;\n", 298 &ctx); 299 300 // Declare variables to store all of the parameters. Allocated parameters 301 // such as `ArrayRef` and `StringRef` must provide a `storageType`. Store 302 // FailureOr<T> to defer type construction for parameters that are parsed in 303 // a loop (parsers return FailureOr anyways). 304 ArrayRef<AttrOrTypeParameter> params = def.getParameters(); 305 for (const AttrOrTypeParameter ¶m : params) { 306 os << formatv("::mlir::FailureOr<{0}> _result_{1};\n", 307 param.getCppStorageType(), param.getName()); 308 if (auto *selfTypeParam = dyn_cast<AttributeSelfTypeParameter>(¶m)) 309 genAttrSelfTypeParser(os, ctx, *selfTypeParam); 310 } 311 312 // Generate call to each parameter parser. 313 for (FormatElement *el : elements) 314 genElementParser(el, ctx, os); 315 316 // Emit an assert for each mandatory parameter. Triggering an assert means 317 // the generated parser is incorrect (i.e. there is a bug in this code). 318 for (const AttrOrTypeParameter ¶m : params) { 319 if (param.isOptional()) 320 continue; 321 os << formatv("assert(::mlir::succeeded(_result_{0}));\n", param.getName()); 322 } 323 324 // Generate call to the attribute or type builder. Use the checked getter 325 // if one was generated. 326 if (def.genVerifyDecl() || def.genVerifyInvariantsImpl()) { 327 os << tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()", 328 &ctx, def.getCppClassName()); 329 } else { 330 os << tgfmt("return $0::get($_parser.getContext()", &ctx, 331 def.getCppClassName()); 332 } 333 for (const AttrOrTypeParameter ¶m : params) { 334 os << ",\n "; 335 std::string paramSelfStr; 336 llvm::raw_string_ostream selfOs(paramSelfStr); 337 if (std::optional<StringRef> defaultValue = param.getDefaultValue()) { 338 selfOs << formatv("(_result_{0}.value_or(", param.getName()) 339 << tgfmt(*defaultValue, &ctx) << "))"; 340 } else { 341 selfOs << formatv("(*_result_{0})", param.getName()); 342 } 343 ctx.addSubst(param.getName(), selfOs.str()); 344 os << param.getCppType() << "(" 345 << tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str())) 346 << ")"; 347 } 348 os << ");"; 349 } 350 351 void DefFormat::genElementParser(FormatElement *el, FmtContext &ctx, 352 MethodBody &os) { 353 if (auto *literal = dyn_cast<LiteralElement>(el)) 354 return genLiteralParser(literal->getSpelling(), ctx, os); 355 if (auto *var = dyn_cast<ParameterElement>(el)) 356 return genVariableParser(var, ctx, os); 357 if (auto *params = dyn_cast<ParamsDirective>(el)) 358 return genParamsParser(params, ctx, os); 359 if (auto *strct = dyn_cast<StructDirective>(el)) 360 return genStructParser(strct, ctx, os); 361 if (auto *custom = dyn_cast<CustomDirective>(el)) 362 return genCustomParser(custom, ctx, os); 363 if (auto *optional = dyn_cast<OptionalElement>(el)) 364 return genOptionalGroupParser(optional, ctx, os); 365 if (isa<WhitespaceElement>(el)) 366 return; 367 368 llvm_unreachable("unknown format element"); 369 } 370 371 void DefFormat::genLiteralParser(StringRef value, FmtContext &ctx, 372 MethodBody &os, bool isOptional) { 373 os << "// Parse literal '" << value << "'\n"; 374 os << tgfmt("if ($_parser.parse", &ctx); 375 if (isOptional) 376 os << "Optional"; 377 if (value.front() == '_' || isalpha(value.front())) { 378 os << "Keyword(\"" << value << "\")"; 379 } else { 380 os << StringSwitch<StringRef>(value) 381 .Case("->", "Arrow") 382 .Case(":", "Colon") 383 .Case(",", "Comma") 384 .Case("=", "Equal") 385 .Case("<", "Less") 386 .Case(">", "Greater") 387 .Case("{", "LBrace") 388 .Case("}", "RBrace") 389 .Case("(", "LParen") 390 .Case(")", "RParen") 391 .Case("[", "LSquare") 392 .Case("]", "RSquare") 393 .Case("?", "Question") 394 .Case("+", "Plus") 395 .Case("*", "Star") 396 .Case("...", "Ellipsis") 397 << "()"; 398 } 399 if (isOptional) { 400 // Leave the `if` unclosed to guard optional groups. 401 return; 402 } 403 // Parser will emit an error 404 os << ") return {};\n"; 405 } 406 407 void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx, 408 MethodBody &os) { 409 // Check for a custom parser. Use the default attribute parser otherwise. 410 const AttrOrTypeParameter ¶m = el->getParam(); 411 auto customParser = param.getParser(); 412 auto parser = 413 customParser ? *customParser : StringRef(defaultParameterParser); 414 415 // If the variable points to a dialect specific entity (type of attribute), 416 // we force load the dialect now before trying to parse it. 417 std::string dialectLoading; 418 if (auto *defInit = dyn_cast<llvm::DefInit>(param.getDef())) { 419 auto *dialectValue = defInit->getDef()->getValue("dialect"); 420 if (dialectValue) { 421 if (auto *dialectInit = 422 dyn_cast<llvm::DefInit>(dialectValue->getValue())) { 423 Dialect dialect(dialectInit->getDef()); 424 auto cppNamespace = dialect.getCppNamespace(); 425 std::string name = dialect.getCppClassName(); 426 if (name != "BuiltinDialect" || cppNamespace != "::mlir") { 427 dialectLoading = ("\nodsParser.getContext()->getOrLoadDialect<" + 428 cppNamespace + "::" + name + ">();") 429 .str(); 430 } 431 } 432 } 433 } 434 os << formatv(variableParser, param.getName(), 435 tgfmt(parser, &ctx, param.getCppStorageType()), 436 tgfmt(parserErrorStr, &ctx), def.getName(), param.getCppType(), 437 dialectLoading); 438 } 439 440 void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx, 441 MethodBody &os) { 442 os << "// Parse parameter list\n"; 443 444 // If there are optional parameters, we need to switch to `parseOptionalComma` 445 // if there are no more required parameters after a certain point. 446 bool hasOptional = el->hasOptionalParams(); 447 if (hasOptional) { 448 // Wrap everything in a do-while so that we can `break`. 449 os << "do {\n"; 450 os.indent(); 451 } 452 453 ArrayRef<ParameterElement *> params = el->getParams(); 454 using IteratorT = ParameterElement *const *; 455 IteratorT it = params.begin(); 456 457 // Find the last required parameter. Commas become optional aftewards. 458 // Note: IteratorT's copy assignment is deleted. 459 ParameterElement *lastReq = nullptr; 460 for (ParameterElement *param : params) 461 if (!param->isOptional()) 462 lastReq = param; 463 IteratorT lastReqIt = lastReq ? llvm::find(params, lastReq) : params.begin(); 464 465 auto eachFn = [&](ParameterElement *el) { genVariableParser(el, ctx, os); }; 466 auto betweenFn = [&](IteratorT it) { 467 ParameterElement *el = *std::prev(it); 468 // Parse a comma if the last optional parameter had a value. 469 if (el->isOptional()) { 470 os << formatv("if (::mlir::succeeded(_result_{0}) && !({1})) {{\n", 471 el->getName(), 472 el->genIsPresent(ctx, "(*_result_" + el->getName() + ")")); 473 os.indent(); 474 } 475 if (it <= lastReqIt) { 476 genLiteralParser(",", ctx, os); 477 } else { 478 genLiteralParser(",", ctx, os, /*isOptional=*/true); 479 os << ") break;\n"; 480 } 481 if (el->isOptional()) 482 os.unindent() << "}\n"; 483 }; 484 485 // llvm::interleave 486 if (it != params.end()) { 487 eachFn(*it++); 488 for (IteratorT e = params.end(); it != e; ++it) { 489 betweenFn(it); 490 eachFn(*it); 491 } 492 } 493 494 if (hasOptional) 495 os.unindent() << "} while(false);\n"; 496 } 497 498 void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx, 499 MethodBody &os) { 500 // Loop declaration for struct parser with only required parameters. 501 // 502 // $0: Number of expected parameters. 503 const char *const loopHeader = R"( 504 for (unsigned odsStructIndex = 0; odsStructIndex < $0; ++odsStructIndex) { 505 )"; 506 507 // Loop body start for struct parser. 508 const char *const loopStart = R"( 509 ::llvm::StringRef _paramKey; 510 if ($_parser.parseKeyword(&_paramKey)) { 511 $_parser.emitError($_parser.getCurrentLocation(), 512 "expected a parameter name in struct"); 513 return {}; 514 } 515 if (!_loop_body(_paramKey)) return {}; 516 )"; 517 518 // Struct parser loop end. Check for duplicate or unknown struct parameters. 519 // 520 // {0}: Code template for printing an error. 521 const char *const loopEnd = R"({{ 522 {0}"duplicate or unknown struct parameter name: ") << _paramKey; 523 return {{}; 524 } 525 )"; 526 527 // Struct parser loop terminator. Parse a comma except on the last element. 528 // 529 // {0}: Number of elements in the struct. 530 const char *const loopTerminator = R"( 531 if ((odsStructIndex != {0} - 1) && odsParser.parseComma()) 532 return {{}; 533 } 534 )"; 535 536 // Check that a mandatory parameter was parse. 537 // 538 // {0}: Name of the parameter. 539 const char *const checkParam = R"( 540 if (!_seen_{0}) { 541 {1}"struct is missing required parameter: ") << "{0}"; 542 return {{}; 543 } 544 )"; 545 546 // First iteration of the loop parsing an optional struct. 547 const char *const optionalStructFirst = R"( 548 ::llvm::StringRef _paramKey; 549 if (!$_parser.parseOptionalKeyword(&_paramKey)) { 550 if (!_loop_body(_paramKey)) return {}; 551 while (!$_parser.parseOptionalComma()) { 552 )"; 553 554 os << "// Parse parameter struct\n"; 555 556 // Declare a "seen" variable for each key. 557 for (ParameterElement *param : el->getParams()) 558 os << formatv("bool _seen_{0} = false;\n", param->getName()); 559 560 // Generate the body of the parsing loop inside a lambda. 561 os << "{\n"; 562 os.indent() 563 << "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n"; 564 genLiteralParser("=", ctx, os.indent()); 565 for (ParameterElement *param : el->getParams()) { 566 os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n" 567 " _seen_{0} = true;\n", 568 param->getName()); 569 genVariableParser(param, ctx, os.indent()); 570 os.unindent() << "} else "; 571 // Print the check for duplicate or unknown parameter. 572 } 573 os.getStream().printReindented(strfmt(loopEnd, tgfmt(parserErrorStr, &ctx))); 574 os << "return true;\n"; 575 os.unindent() << "};\n"; 576 577 // Generate the parsing loop. If optional parameters are present, then the 578 // parse loop is guarded by commas. 579 unsigned numOptional = llvm::count_if(el->getParams(), paramIsOptional); 580 if (numOptional) { 581 // If the struct itself is optional, pull out the first iteration. 582 if (numOptional == el->getNumParams()) { 583 os.getStream().printReindented(tgfmt(optionalStructFirst, &ctx).str()); 584 os.indent(); 585 } else { 586 os << "do {\n"; 587 } 588 } else { 589 os.getStream().printReindented( 590 tgfmt(loopHeader, &ctx, el->getNumParams()).str()); 591 } 592 os.indent(); 593 os.getStream().printReindented(tgfmt(loopStart, &ctx).str()); 594 os.unindent(); 595 596 // Print the loop terminator. For optional parameters, we have to check that 597 // all mandatory parameters have been parsed. 598 // The whole struct is optional if all its parameters are optional. 599 if (numOptional) { 600 if (numOptional == el->getNumParams()) { 601 os << "}\n"; 602 os.unindent() << "}\n"; 603 } else { 604 os << tgfmt("} while(!$_parser.parseOptionalComma());\n", &ctx); 605 for (ParameterElement *param : el->getParams()) { 606 if (param->isOptional()) 607 continue; 608 os.getStream().printReindented( 609 strfmt(checkParam, param->getName(), tgfmt(parserErrorStr, &ctx))); 610 } 611 } 612 } else { 613 // Because the loop loops N times and each non-failing iteration sets 1 of 614 // N flags, successfully exiting the loop means that all parameters have 615 // been seen. `parseOptionalComma` would cause issues with any formats that 616 // use "struct(...) `,`" beacuse structs aren't sounded by braces. 617 os.getStream().printReindented(strfmt(loopTerminator, el->getNumParams())); 618 } 619 os.unindent() << "}\n"; 620 } 621 622 void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx, 623 MethodBody &os, bool isOptional) { 624 os << "{\n"; 625 os.indent(); 626 627 // Bound variables are passed directly to the parser as `FailureOr<T> &`. 628 // Referenced variables are passed as `T`. The custom parser fails if it 629 // returns failure or if any of the required parameters failed. 630 os << tgfmt("auto odsCustomLoc = $_parser.getCurrentLocation();\n", &ctx); 631 os << "(void)odsCustomLoc;\n"; 632 os << tgfmt("auto odsCustomResult = parse$0($_parser", &ctx, el->getName()); 633 os.indent(); 634 for (FormatElement *arg : el->getArguments()) { 635 os << ",\n"; 636 if (auto *param = dyn_cast<ParameterElement>(arg)) 637 os << "::mlir::detail::unwrapForCustomParse(_result_" << param->getName() 638 << ")"; 639 else if (auto *ref = dyn_cast<RefDirective>(arg)) 640 os << "*_result_" << cast<ParameterElement>(ref->getArg())->getName(); 641 else 642 os << tgfmt(cast<StringElement>(arg)->getValue(), &ctx); 643 } 644 os.unindent() << ");\n"; 645 if (isOptional) { 646 os << "if (!odsCustomResult.has_value()) return {};\n"; 647 os << "if (::mlir::failed(*odsCustomResult)) return ::mlir::failure();\n"; 648 } else { 649 os << "if (::mlir::failed(odsCustomResult)) return {};\n"; 650 } 651 for (FormatElement *arg : el->getArguments()) { 652 if (auto *param = dyn_cast<ParameterElement>(arg)) { 653 if (param->isOptional()) 654 continue; 655 os << formatv("if (::mlir::failed(_result_{0})) {{\n", param->getName()); 656 os.indent() << tgfmt("$_parser.emitError(odsCustomLoc, ", &ctx) 657 << "\"custom parser failed to parse parameter '" 658 << param->getName() << "'\");\n"; 659 os << "return " << (isOptional ? "::mlir::failure()" : "{}") << ";\n"; 660 os.unindent() << "}\n"; 661 } 662 } 663 664 os.unindent() << "}\n"; 665 } 666 667 void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx, 668 MethodBody &os) { 669 ArrayRef<FormatElement *> thenElements = 670 el->getThenElements(/*parseable=*/true); 671 672 FormatElement *first = thenElements.front(); 673 const auto guardOn = [&](auto params) { 674 os << "if (!("; 675 llvm::interleave( 676 params, os, 677 [&](ParameterElement *el) { 678 os << formatv("(::mlir::succeeded(_result_{0}) && *_result_{0})", 679 el->getName()); 680 }, 681 " || "); 682 os << ")) {\n"; 683 }; 684 if (auto *literal = dyn_cast<LiteralElement>(first)) { 685 genLiteralParser(literal->getSpelling(), ctx, os, /*isOptional=*/true); 686 os << ") {\n"; 687 } else if (auto *param = dyn_cast<ParameterElement>(first)) { 688 genVariableParser(param, ctx, os); 689 guardOn(llvm::ArrayRef(param)); 690 } else if (auto *params = dyn_cast<ParamsDirective>(first)) { 691 genParamsParser(params, ctx, os); 692 guardOn(params->getParams()); 693 } else if (auto *custom = dyn_cast<CustomDirective>(first)) { 694 os << "if (auto result = [&]() -> ::mlir::OptionalParseResult {\n"; 695 os.indent(); 696 genCustomParser(custom, ctx, os, /*isOptional=*/true); 697 os << "return ::mlir::success();\n"; 698 os.unindent(); 699 os << "}(); result.has_value() && ::mlir::failed(*result)) {\n"; 700 os.indent(); 701 os << "return {};\n"; 702 os.unindent(); 703 os << "} else if (result.has_value()) {\n"; 704 } else { 705 auto *strct = cast<StructDirective>(first); 706 genStructParser(strct, ctx, os); 707 guardOn(params->getParams()); 708 } 709 os.indent(); 710 711 // Generate the parsers for the rest of the thenElements. 712 for (FormatElement *element : el->getElseElements(/*parseable=*/true)) 713 genElementParser(element, ctx, os); 714 os.unindent() << "} else {\n"; 715 os.indent(); 716 for (FormatElement *element : thenElements.drop_front()) 717 genElementParser(element, ctx, os); 718 os.unindent() << "}\n"; 719 } 720 721 //===----------------------------------------------------------------------===// 722 // PrinterGen 723 //===----------------------------------------------------------------------===// 724 725 void DefFormat::genPrinter(MethodBody &os) { 726 FmtContext ctx; 727 ctx.addSubst("_printer", "odsPrinter"); 728 ctx.addSubst("_ctxt", "getContext()"); 729 ctx.withBuilder("odsBuilder"); 730 os.indent(); 731 os << "::mlir::Builder odsBuilder(getContext());\n"; 732 733 // Generate printers. 734 shouldEmitSpace = true; 735 lastWasPunctuation = false; 736 for (FormatElement *el : elements) 737 genElementPrinter(el, ctx, os); 738 } 739 740 void DefFormat::genElementPrinter(FormatElement *el, FmtContext &ctx, 741 MethodBody &os) { 742 if (auto *literal = dyn_cast<LiteralElement>(el)) 743 return genLiteralPrinter(literal->getSpelling(), ctx, os); 744 if (auto *params = dyn_cast<ParamsDirective>(el)) 745 return genParamsPrinter(params, ctx, os); 746 if (auto *strct = dyn_cast<StructDirective>(el)) 747 return genStructPrinter(strct, ctx, os); 748 if (auto *custom = dyn_cast<CustomDirective>(el)) 749 return genCustomPrinter(custom, ctx, os); 750 if (auto *var = dyn_cast<ParameterElement>(el)) 751 return genVariablePrinter(var, ctx, os); 752 if (auto *optional = dyn_cast<OptionalElement>(el)) 753 return genOptionalGroupPrinter(optional, ctx, os); 754 if (auto *whitespace = dyn_cast<WhitespaceElement>(el)) 755 return genWhitespacePrinter(whitespace, ctx, os); 756 757 llvm::PrintFatalError("unsupported format element"); 758 } 759 760 void DefFormat::genLiteralPrinter(StringRef value, FmtContext &ctx, 761 MethodBody &os) { 762 // Don't insert a space before certain punctuation. 763 bool needSpace = 764 shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation); 765 os << tgfmt("$_printer$0 << \"$1\";\n", &ctx, needSpace ? " << ' '" : "", 766 value); 767 768 // Update the flags. 769 shouldEmitSpace = 770 value.size() != 1 || !StringRef("<({[").contains(value.front()); 771 lastWasPunctuation = value.front() != '_' && !isalpha(value.front()); 772 } 773 774 void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx, 775 MethodBody &os, bool skipGuard) { 776 const AttrOrTypeParameter ¶m = el->getParam(); 777 ctx.withSelf(param.getAccessorName() + "()"); 778 779 // Guard the printer on the presence of optional parameters and that they 780 // aren't equal to their default values (if they have one). 781 if (el->isOptional() && !skipGuard) { 782 el->genPrintGuard(ctx, os << "if (") << ") {\n"; 783 os.indent(); 784 } 785 786 // Insert a space before the next parameter, if necessary. 787 if (shouldEmitSpace || !lastWasPunctuation) 788 os << tgfmt("$_printer << ' ';\n", &ctx); 789 shouldEmitSpace = true; 790 lastWasPunctuation = false; 791 792 if (el->shouldBeQualified()) 793 os << tgfmt(qualifiedParameterPrinter, &ctx) << ";\n"; 794 else if (auto printer = param.getPrinter()) 795 os << tgfmt(*printer, &ctx) << ";\n"; 796 else 797 os << tgfmt(defaultParameterPrinter, &ctx) << ";\n"; 798 799 if (el->isOptional() && !skipGuard) 800 os.unindent() << "}\n"; 801 } 802 803 /// Generate code to guard printing on the presence of any optional parameters. 804 template <typename ParameterRange> 805 static void guardOnAny(FmtContext &ctx, MethodBody &os, ParameterRange &¶ms, 806 bool inverted = false) { 807 os << "if ("; 808 if (inverted) 809 os << "!("; 810 llvm::interleave( 811 params, os, 812 [&](ParameterElement *param) { param->genPrintGuard(ctx, os); }, " || "); 813 if (inverted) 814 os << ")"; 815 os << ") {\n"; 816 os.indent(); 817 } 818 819 void DefFormat::genCommaSeparatedPrinter( 820 ArrayRef<ParameterElement *> params, FmtContext &ctx, MethodBody &os, 821 function_ref<void(ParameterElement *)> extra) { 822 // Emit a space if necessary, but only if the struct is present. 823 if (shouldEmitSpace || !lastWasPunctuation) { 824 bool allOptional = llvm::all_of(params, paramIsOptional); 825 if (allOptional) 826 guardOnAny(ctx, os, params); 827 os << tgfmt("$_printer << ' ';\n", &ctx); 828 if (allOptional) 829 os.unindent() << "}\n"; 830 } 831 832 // The first printed element does not need to emit a comma. 833 os << "{\n"; 834 os.indent() << "bool _firstPrinted = true;\n"; 835 for (ParameterElement *param : params) { 836 if (param->isOptional()) { 837 param->genPrintGuard(ctx, os << "if (") << ") {\n"; 838 os.indent(); 839 } 840 os << tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx); 841 os << "_firstPrinted = false;\n"; 842 extra(param); 843 shouldEmitSpace = false; 844 lastWasPunctuation = true; 845 genVariablePrinter(param, ctx, os); 846 if (param->isOptional()) 847 os.unindent() << "}\n"; 848 } 849 os.unindent() << "}\n"; 850 } 851 852 void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx, 853 MethodBody &os) { 854 genCommaSeparatedPrinter(llvm::to_vector(el->getParams()), ctx, os, 855 [&](ParameterElement *param) {}); 856 } 857 858 void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx, 859 MethodBody &os) { 860 genCommaSeparatedPrinter( 861 llvm::to_vector(el->getParams()), ctx, os, [&](ParameterElement *param) { 862 os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName()); 863 }); 864 } 865 866 void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx, 867 MethodBody &os) { 868 // Insert a space before the custom directive, if necessary. 869 if (shouldEmitSpace || !lastWasPunctuation) 870 os << tgfmt("$_printer << ' ';\n", &ctx); 871 shouldEmitSpace = true; 872 lastWasPunctuation = false; 873 874 os << tgfmt("print$0($_printer", &ctx, el->getName()); 875 os.indent(); 876 for (FormatElement *arg : el->getArguments()) { 877 os << ",\n"; 878 if (auto *param = dyn_cast<ParameterElement>(arg)) { 879 os << param->getParam().getAccessorName() << "()"; 880 } else if (auto *ref = dyn_cast<RefDirective>(arg)) { 881 os << cast<ParameterElement>(ref->getArg())->getParam().getAccessorName() 882 << "()"; 883 } else { 884 os << tgfmt(cast<StringElement>(arg)->getValue(), &ctx); 885 } 886 } 887 os.unindent() << ");\n"; 888 } 889 890 void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx, 891 MethodBody &os) { 892 FormatElement *anchor = el->getAnchor(); 893 if (auto *param = dyn_cast<ParameterElement>(anchor)) { 894 guardOnAny(ctx, os, llvm::ArrayRef(param), el->isInverted()); 895 } else if (auto *params = dyn_cast<ParamsDirective>(anchor)) { 896 guardOnAny(ctx, os, params->getParams(), el->isInverted()); 897 } else if (auto *strct = dyn_cast<StructDirective>(anchor)) { 898 guardOnAny(ctx, os, strct->getParams(), el->isInverted()); 899 } else { 900 auto *custom = cast<CustomDirective>(anchor); 901 guardOnAny(ctx, os, 902 llvm::make_filter_range( 903 llvm::map_range(custom->getArguments(), 904 [](FormatElement *el) { 905 return dyn_cast<ParameterElement>(el); 906 }), 907 [](ParameterElement *param) { return !!param; }), 908 el->isInverted()); 909 } 910 // Generate the printer for the contained elements. 911 { 912 llvm::SaveAndRestore shouldEmitSpaceFlag(shouldEmitSpace); 913 llvm::SaveAndRestore lastWasPunctuationFlag(lastWasPunctuation); 914 for (FormatElement *element : el->getThenElements()) 915 genElementPrinter(element, ctx, os); 916 } 917 os.unindent() << "} else {\n"; 918 os.indent(); 919 for (FormatElement *element : el->getElseElements()) 920 genElementPrinter(element, ctx, os); 921 os.unindent() << "}\n"; 922 } 923 924 void DefFormat::genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx, 925 MethodBody &os) { 926 if (el->getValue() == "\\n") { 927 // FIXME: The newline should be `printer.printNewLine()`, i.e., handled by 928 // the printer. 929 os << tgfmt("$_printer << '\\n';\n", &ctx); 930 } else if (!el->getValue().empty()) { 931 os << tgfmt("$_printer << \"$0\";\n", &ctx, el->getValue()); 932 } else { 933 lastWasPunctuation = true; 934 } 935 shouldEmitSpace = false; 936 } 937 938 //===----------------------------------------------------------------------===// 939 // DefFormatParser 940 //===----------------------------------------------------------------------===// 941 942 namespace { 943 class DefFormatParser : public FormatParser { 944 public: 945 DefFormatParser(llvm::SourceMgr &mgr, const AttrOrTypeDef &def) 946 : FormatParser(mgr, def.getLoc()[0]), def(def), 947 seenParams(def.getNumParameters()) {} 948 949 /// Parse the attribute or type format and create the format elements. 950 FailureOr<DefFormat> parse(); 951 952 protected: 953 /// Verify the parsed elements. 954 LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override; 955 /// Verify the elements of a custom directive. 956 LogicalResult 957 verifyCustomDirectiveArguments(SMLoc loc, 958 ArrayRef<FormatElement *> arguments) override; 959 /// Verify the elements of an optional group. 960 LogicalResult verifyOptionalGroupElements(SMLoc loc, 961 ArrayRef<FormatElement *> elements, 962 FormatElement *anchor) override; 963 964 LogicalResult markQualified(SMLoc loc, FormatElement *element) override; 965 966 /// Parse an attribute or type variable. 967 FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name, 968 Context ctx) override; 969 /// Parse an attribute or type format directive. 970 FailureOr<FormatElement *> 971 parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override; 972 973 private: 974 /// Parse a `params` directive. 975 FailureOr<FormatElement *> parseParamsDirective(SMLoc loc, Context ctx); 976 /// Parse a `struct` directive. 977 FailureOr<FormatElement *> parseStructDirective(SMLoc loc, Context ctx); 978 979 /// Attribute or type tablegen def. 980 const AttrOrTypeDef &def; 981 982 /// Seen attribute or type parameters. 983 BitVector seenParams; 984 }; 985 } // namespace 986 987 LogicalResult DefFormatParser::verify(SMLoc loc, 988 ArrayRef<FormatElement *> elements) { 989 // Check that all parameters are referenced in the format. 990 for (auto [index, param] : llvm::enumerate(def.getParameters())) { 991 if (param.isOptional()) 992 continue; 993 if (!seenParams.test(index)) { 994 if (isa<AttributeSelfTypeParameter>(param)) 995 continue; 996 return emitError(loc, "format is missing reference to parameter: " + 997 param.getName()); 998 } 999 if (isa<AttributeSelfTypeParameter>(param)) { 1000 return emitError(loc, 1001 "unexpected self type parameter in assembly format"); 1002 } 1003 } 1004 if (elements.empty()) 1005 return success(); 1006 // A `struct` directive that contains optional parameters cannot be followed 1007 // by a comma literal, which is ambiguous. 1008 for (auto it : llvm::zip(elements.drop_back(), elements.drop_front())) { 1009 auto *structEl = dyn_cast<StructDirective>(std::get<0>(it)); 1010 auto *literalEl = dyn_cast<LiteralElement>(std::get<1>(it)); 1011 if (!structEl || !literalEl) 1012 continue; 1013 if (literalEl->getSpelling() == "," && structEl->hasOptionalParams()) { 1014 return emitError(loc, "`struct` directive with optional parameters " 1015 "cannot be followed by a comma literal"); 1016 } 1017 } 1018 return success(); 1019 } 1020 1021 LogicalResult DefFormatParser::verifyCustomDirectiveArguments( 1022 SMLoc loc, ArrayRef<FormatElement *> arguments) { 1023 // Arguments are fully verified by the parser context. 1024 return success(); 1025 } 1026 1027 LogicalResult 1028 DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc, 1029 ArrayRef<FormatElement *> elements, 1030 FormatElement *anchor) { 1031 // `params` and `struct` directives are allowed only if all the contained 1032 // parameters are optional. 1033 for (FormatElement *el : elements) { 1034 if (auto *param = dyn_cast<ParameterElement>(el)) { 1035 if (!param->isOptional()) { 1036 return emitError(loc, 1037 "parameters in an optional group must be optional"); 1038 } 1039 } else if (auto *params = dyn_cast<ParamsDirective>(el)) { 1040 if (llvm::any_of(params->getParams(), paramNotOptional)) { 1041 return emitError(loc, "`params` directive allowed in optional group " 1042 "only if all parameters are optional"); 1043 } 1044 } else if (auto *strct = dyn_cast<StructDirective>(el)) { 1045 if (llvm::any_of(strct->getParams(), paramNotOptional)) { 1046 return emitError(loc, "`struct` is only allowed in an optional group " 1047 "if all captured parameters are optional"); 1048 } 1049 } else if (auto *custom = dyn_cast<CustomDirective>(el)) { 1050 for (FormatElement *el : custom->getArguments()) { 1051 // If the custom argument is a variable, then it must be optional. 1052 if (auto *param = dyn_cast<ParameterElement>(el)) 1053 if (!param->isOptional()) 1054 return emitError(loc, 1055 "`custom` is only allowed in an optional group if " 1056 "all captured parameters are optional"); 1057 } 1058 } 1059 } 1060 // The anchor must be a parameter or one of the aforementioned directives. 1061 if (anchor) { 1062 if (!isa<ParameterElement, ParamsDirective, StructDirective, 1063 CustomDirective>(anchor)) { 1064 return emitError( 1065 loc, "optional group anchor must be a parameter or directive"); 1066 } 1067 // If the anchor is a custom directive, make sure at least one of its 1068 // arguments is a bound parameter. 1069 if (auto *custom = dyn_cast<CustomDirective>(anchor)) { 1070 const auto *bound = 1071 llvm::find_if(custom->getArguments(), [](FormatElement *el) { 1072 return isa<ParameterElement>(el); 1073 }); 1074 if (bound == custom->getArguments().end()) 1075 return emitError(loc, "`custom` directive with no bound parameters " 1076 "cannot be used as optional group anchor"); 1077 } 1078 } 1079 return success(); 1080 } 1081 1082 LogicalResult DefFormatParser::markQualified(SMLoc loc, 1083 FormatElement *element) { 1084 if (!isa<ParameterElement>(element)) 1085 return emitError(loc, "`qualified` argument list expected a variable"); 1086 cast<ParameterElement>(element)->setShouldBeQualified(); 1087 return success(); 1088 } 1089 1090 FailureOr<DefFormat> DefFormatParser::parse() { 1091 FailureOr<std::vector<FormatElement *>> elements = FormatParser::parse(); 1092 if (failed(elements)) 1093 return failure(); 1094 return DefFormat(def, std::move(*elements)); 1095 } 1096 1097 FailureOr<FormatElement *> 1098 DefFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) { 1099 // Lookup the parameter. 1100 ArrayRef<AttrOrTypeParameter> params = def.getParameters(); 1101 auto *it = llvm::find_if( 1102 params, [&](auto ¶m) { return param.getName() == name; }); 1103 1104 // Check that the parameter reference is valid. 1105 if (it == params.end()) { 1106 return emitError(loc, 1107 def.getName() + " has no parameter named '" + name + "'"); 1108 } 1109 auto idx = std::distance(params.begin(), it); 1110 1111 if (ctx != RefDirectiveContext) { 1112 // Check that the variable has not already been bound. 1113 if (seenParams.test(idx)) 1114 return emitError(loc, "duplicate parameter '" + name + "'"); 1115 seenParams.set(idx); 1116 1117 // Otherwise, to be referenced, a variable must have been bound. 1118 } else if (!seenParams.test(idx) && !isa<AttributeSelfTypeParameter>(*it)) { 1119 return emitError(loc, "parameter '" + name + 1120 "' must be bound before it is referenced"); 1121 } 1122 1123 return create<ParameterElement>(*it); 1124 } 1125 1126 FailureOr<FormatElement *> 1127 DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, 1128 Context ctx) { 1129 1130 switch (kind) { 1131 case FormatToken::kw_qualified: 1132 return parseQualifiedDirective(loc, ctx); 1133 case FormatToken::kw_params: 1134 return parseParamsDirective(loc, ctx); 1135 case FormatToken::kw_struct: 1136 return parseStructDirective(loc, ctx); 1137 default: 1138 return emitError(loc, "unsupported directive kind"); 1139 } 1140 } 1141 1142 FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc, 1143 Context ctx) { 1144 // It doesn't make sense to allow references to all parameters in a custom 1145 // directive because parameters are the only things that can be bound. 1146 if (ctx != TopLevelContext && ctx != StructDirectiveContext) { 1147 return emitError(loc, "`params` can only be used at the top-level context " 1148 "or within a `struct` directive"); 1149 } 1150 1151 // Collect all of the attribute's or type's parameters and ensure that none of 1152 // the parameters have already been captured. 1153 std::vector<ParameterElement *> vars; 1154 for (const auto &it : llvm::enumerate(def.getParameters())) { 1155 if (seenParams.test(it.index())) { 1156 return emitError(loc, "`params` captures duplicate parameter: " + 1157 it.value().getName()); 1158 } 1159 // Self-type parameters are handled separately from the rest of the 1160 // parameters. 1161 if (isa<AttributeSelfTypeParameter>(it.value())) 1162 continue; 1163 seenParams.set(it.index()); 1164 vars.push_back(create<ParameterElement>(it.value())); 1165 } 1166 return create<ParamsDirective>(std::move(vars)); 1167 } 1168 1169 FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc, 1170 Context ctx) { 1171 if (ctx != TopLevelContext) 1172 return emitError(loc, "`struct` can only be used at the top-level context"); 1173 1174 if (failed(parseToken(FormatToken::l_paren, 1175 "expected '(' before `struct` argument list"))) 1176 return failure(); 1177 1178 // Parse variables captured by `struct`. 1179 std::vector<ParameterElement *> vars; 1180 1181 // Parse first captured parameter or a `params` directive. 1182 FailureOr<FormatElement *> var = parseElement(StructDirectiveContext); 1183 if (failed(var) || !isa<VariableElement, ParamsDirective>(*var)) { 1184 return emitError(loc, 1185 "`struct` argument list expected a variable or directive"); 1186 } 1187 if (isa<VariableElement>(*var)) { 1188 // Parse any other parameters. 1189 vars.push_back(cast<ParameterElement>(*var)); 1190 while (peekToken().is(FormatToken::comma)) { 1191 consumeToken(); 1192 var = parseElement(StructDirectiveContext); 1193 if (failed(var) || !isa<VariableElement>(*var)) 1194 return emitError(loc, "expected a variable in `struct` argument list"); 1195 vars.push_back(cast<ParameterElement>(*var)); 1196 } 1197 } else { 1198 // `struct(params)` captures all parameters in the attribute or type. 1199 vars = cast<ParamsDirective>(*var)->takeParams(); 1200 } 1201 1202 if (failed(parseToken(FormatToken::r_paren, 1203 "expected ')' at the end of an argument list"))) 1204 return failure(); 1205 1206 return create<StructDirective>(std::move(vars)); 1207 } 1208 1209 //===----------------------------------------------------------------------===// 1210 // Interface 1211 //===----------------------------------------------------------------------===// 1212 1213 void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def, 1214 MethodBody &parser, 1215 MethodBody &printer) { 1216 llvm::SourceMgr mgr; 1217 mgr.AddNewSourceBuffer( 1218 llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()), SMLoc()); 1219 1220 // Parse the custom assembly format> 1221 DefFormatParser fmtParser(mgr, def); 1222 FailureOr<DefFormat> format = fmtParser.parse(); 1223 if (failed(format)) { 1224 if (formatErrorIsFatal) 1225 PrintFatalError(def.getLoc(), "failed to parse assembly format"); 1226 return; 1227 } 1228 1229 // Generate the parser and printer. 1230 format->genParser(parser); 1231 format->genPrinter(printer); 1232 } 1233