1 //===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the parser for the MLIR Types. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Parser.h" 14 15 #include "AsmParserImpl.h" 16 #include "mlir/AsmParser/AsmParserState.h" 17 #include "mlir/IR/AffineMap.h" 18 #include "mlir/IR/BuiltinAttributes.h" 19 #include "mlir/IR/BuiltinDialect.h" 20 #include "mlir/IR/BuiltinTypes.h" 21 #include "mlir/IR/DialectImplementation.h" 22 #include "mlir/IR/DialectResourceBlobManager.h" 23 #include "mlir/IR/IntegerSet.h" 24 #include "llvm/ADT/StringExtras.h" 25 #include "llvm/Support/Endian.h" 26 #include <optional> 27 28 using namespace mlir; 29 using namespace mlir::detail; 30 31 /// Parse an arbitrary attribute. 32 /// 33 /// attribute-value ::= `unit` 34 /// | bool-literal 35 /// | integer-literal (`:` (index-type | integer-type))? 36 /// | float-literal (`:` float-type)? 37 /// | string-literal (`:` type)? 38 /// | type 39 /// | `[` `:` (integer-type | float-type) tensor-literal `]` 40 /// | `[` (attribute-value (`,` attribute-value)*)? `]` 41 /// | `{` (attribute-entry (`,` attribute-entry)*)? `}` 42 /// | symbol-ref-id (`::` symbol-ref-id)* 43 /// | `dense` `<` tensor-literal `>` `:` 44 /// (tensor-type | vector-type) 45 /// | `sparse` `<` attribute-value `,` attribute-value `>` 46 /// `:` (tensor-type | vector-type) 47 /// | `strided` `<` `[` comma-separated-int-or-question `]` 48 /// (`,` `offset` `:` integer-literal)? `>` 49 /// | distinct-attribute 50 /// | extended-attribute 51 /// 52 Attribute Parser::parseAttribute(Type type) { 53 switch (getToken().getKind()) { 54 // Parse an AffineMap or IntegerSet attribute. 55 case Token::kw_affine_map: { 56 consumeToken(Token::kw_affine_map); 57 58 AffineMap map; 59 if (parseToken(Token::less, "expected '<' in affine map") || 60 parseAffineMapReference(map) || 61 parseToken(Token::greater, "expected '>' in affine map")) 62 return Attribute(); 63 return AffineMapAttr::get(map); 64 } 65 case Token::kw_affine_set: { 66 consumeToken(Token::kw_affine_set); 67 68 IntegerSet set; 69 if (parseToken(Token::less, "expected '<' in integer set") || 70 parseIntegerSetReference(set) || 71 parseToken(Token::greater, "expected '>' in integer set")) 72 return Attribute(); 73 return IntegerSetAttr::get(set); 74 } 75 76 // Parse an array attribute. 77 case Token::l_square: { 78 consumeToken(Token::l_square); 79 SmallVector<Attribute, 4> elements; 80 auto parseElt = [&]() -> ParseResult { 81 elements.push_back(parseAttribute()); 82 return elements.back() ? success() : failure(); 83 }; 84 85 if (parseCommaSeparatedListUntil(Token::r_square, parseElt)) 86 return nullptr; 87 return builder.getArrayAttr(elements); 88 } 89 90 // Parse a boolean attribute. 91 case Token::kw_false: 92 consumeToken(Token::kw_false); 93 return builder.getBoolAttr(false); 94 case Token::kw_true: 95 consumeToken(Token::kw_true); 96 return builder.getBoolAttr(true); 97 98 // Parse a dense elements attribute. 99 case Token::kw_dense: 100 return parseDenseElementsAttr(type); 101 102 // Parse a dense resource elements attribute. 103 case Token::kw_dense_resource: 104 return parseDenseResourceElementsAttr(type); 105 106 // Parse a dense array attribute. 107 case Token::kw_array: 108 return parseDenseArrayAttr(type); 109 110 // Parse a dictionary attribute. 111 case Token::l_brace: { 112 NamedAttrList elements; 113 if (parseAttributeDict(elements)) 114 return nullptr; 115 return elements.getDictionary(getContext()); 116 } 117 118 // Parse an extended attribute, i.e. alias or dialect attribute. 119 case Token::hash_identifier: 120 return parseExtendedAttr(type); 121 122 // Parse floating point and integer attributes. 123 case Token::floatliteral: 124 return parseFloatAttr(type, /*isNegative=*/false); 125 case Token::integer: 126 return parseDecOrHexAttr(type, /*isNegative=*/false); 127 case Token::minus: { 128 consumeToken(Token::minus); 129 if (getToken().is(Token::integer)) 130 return parseDecOrHexAttr(type, /*isNegative=*/true); 131 if (getToken().is(Token::floatliteral)) 132 return parseFloatAttr(type, /*isNegative=*/true); 133 134 return (emitWrongTokenError( 135 "expected constant integer or floating point value"), 136 nullptr); 137 } 138 139 // Parse a location attribute. 140 case Token::kw_loc: { 141 consumeToken(Token::kw_loc); 142 143 LocationAttr locAttr; 144 if (parseToken(Token::l_paren, "expected '(' in inline location") || 145 parseLocationInstance(locAttr) || 146 parseToken(Token::r_paren, "expected ')' in inline location")) 147 return Attribute(); 148 return locAttr; 149 } 150 151 // Parse a sparse elements attribute. 152 case Token::kw_sparse: 153 return parseSparseElementsAttr(type); 154 155 // Parse a strided layout attribute. 156 case Token::kw_strided: 157 return parseStridedLayoutAttr(); 158 159 // Parse a distinct attribute. 160 case Token::kw_distinct: 161 return parseDistinctAttr(type); 162 163 // Parse a string attribute. 164 case Token::string: { 165 auto val = getToken().getStringValue(); 166 consumeToken(Token::string); 167 // Parse the optional trailing colon type if one wasn't explicitly provided. 168 if (!type && consumeIf(Token::colon) && !(type = parseType())) 169 return Attribute(); 170 171 return type ? StringAttr::get(val, type) 172 : StringAttr::get(getContext(), val); 173 } 174 175 // Parse a symbol reference attribute. 176 case Token::at_identifier: { 177 // When populating the parser state, this is a list of locations for all of 178 // the nested references. 179 SmallVector<SMRange> referenceLocations; 180 if (state.asmState) 181 referenceLocations.push_back(getToken().getLocRange()); 182 183 // Parse the top-level reference. 184 std::string nameStr = getToken().getSymbolReference(); 185 consumeToken(Token::at_identifier); 186 187 // Parse any nested references. 188 std::vector<FlatSymbolRefAttr> nestedRefs; 189 while (getToken().is(Token::colon)) { 190 // Check for the '::' prefix. 191 const char *curPointer = getToken().getLoc().getPointer(); 192 consumeToken(Token::colon); 193 if (!consumeIf(Token::colon)) { 194 if (getToken().isNot(Token::eof, Token::error)) { 195 state.lex.resetPointer(curPointer); 196 consumeToken(); 197 } 198 break; 199 } 200 // Parse the reference itself. 201 auto curLoc = getToken().getLoc(); 202 if (getToken().isNot(Token::at_identifier)) { 203 emitError(curLoc, "expected nested symbol reference identifier"); 204 return Attribute(); 205 } 206 207 // If we are populating the assembly state, add the location for this 208 // reference. 209 if (state.asmState) 210 referenceLocations.push_back(getToken().getLocRange()); 211 212 std::string nameStr = getToken().getSymbolReference(); 213 consumeToken(Token::at_identifier); 214 nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr)); 215 } 216 SymbolRefAttr symbolRefAttr = 217 SymbolRefAttr::get(getContext(), nameStr, nestedRefs); 218 219 // If we are populating the assembly state, record this symbol reference. 220 if (state.asmState) 221 state.asmState->addUses(symbolRefAttr, referenceLocations); 222 return symbolRefAttr; 223 } 224 225 // Parse a 'unit' attribute. 226 case Token::kw_unit: 227 consumeToken(Token::kw_unit); 228 return builder.getUnitAttr(); 229 230 // Handle completion of an attribute. 231 case Token::code_complete: 232 if (getToken().isCodeCompletionFor(Token::hash_identifier)) 233 return parseExtendedAttr(type); 234 return codeCompleteAttribute(); 235 236 default: 237 // Parse a type attribute. We parse `Optional` here to allow for providing a 238 // better error message. 239 Type type; 240 OptionalParseResult result = parseOptionalType(type); 241 if (!result.has_value()) 242 return emitWrongTokenError("expected attribute value"), Attribute(); 243 return failed(*result) ? Attribute() : TypeAttr::get(type); 244 } 245 } 246 247 /// Parse an optional attribute with the provided type. 248 OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute, 249 Type type) { 250 switch (getToken().getKind()) { 251 case Token::at_identifier: 252 case Token::floatliteral: 253 case Token::integer: 254 case Token::hash_identifier: 255 case Token::kw_affine_map: 256 case Token::kw_affine_set: 257 case Token::kw_dense: 258 case Token::kw_dense_resource: 259 case Token::kw_false: 260 case Token::kw_loc: 261 case Token::kw_sparse: 262 case Token::kw_true: 263 case Token::kw_unit: 264 case Token::l_brace: 265 case Token::l_square: 266 case Token::minus: 267 case Token::string: 268 attribute = parseAttribute(type); 269 return success(attribute != nullptr); 270 271 default: 272 // Parse an optional type attribute. 273 Type type; 274 OptionalParseResult result = parseOptionalType(type); 275 if (result.has_value() && succeeded(*result)) 276 attribute = TypeAttr::get(type); 277 return result; 278 } 279 } 280 OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute, 281 Type type) { 282 return parseOptionalAttributeWithToken(Token::l_square, attribute, type); 283 } 284 OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute, 285 Type type) { 286 return parseOptionalAttributeWithToken(Token::string, attribute, type); 287 } 288 OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result, 289 Type type) { 290 return parseOptionalAttributeWithToken(Token::at_identifier, result, type); 291 } 292 293 /// Attribute dictionary. 294 /// 295 /// attribute-dict ::= `{` `}` 296 /// | `{` attribute-entry (`,` attribute-entry)* `}` 297 /// attribute-entry ::= (bare-id | string-literal) `=` attribute-value 298 /// 299 ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) { 300 llvm::SmallDenseSet<StringAttr> seenKeys; 301 auto parseElt = [&]() -> ParseResult { 302 // The name of an attribute can either be a bare identifier, or a string. 303 std::optional<StringAttr> nameId; 304 if (getToken().is(Token::string)) 305 nameId = builder.getStringAttr(getToken().getStringValue()); 306 else if (getToken().isAny(Token::bare_identifier, Token::inttype) || 307 getToken().isKeyword()) 308 nameId = builder.getStringAttr(getTokenSpelling()); 309 else 310 return emitWrongTokenError("expected attribute name"); 311 312 if (nameId->empty()) 313 return emitError("expected valid attribute name"); 314 315 if (!seenKeys.insert(*nameId).second) 316 return emitError("duplicate key '") 317 << nameId->getValue() << "' in dictionary attribute"; 318 consumeToken(); 319 320 // Lazy load a dialect in the context if there is a possible namespace. 321 auto splitName = nameId->strref().split('.'); 322 if (!splitName.second.empty()) 323 getContext()->getOrLoadDialect(splitName.first); 324 325 // Try to parse the '=' for the attribute value. 326 if (!consumeIf(Token::equal)) { 327 // If there is no '=', we treat this as a unit attribute. 328 attributes.push_back({*nameId, builder.getUnitAttr()}); 329 return success(); 330 } 331 332 auto attr = parseAttribute(); 333 if (!attr) 334 return failure(); 335 attributes.push_back({*nameId, attr}); 336 return success(); 337 }; 338 339 return parseCommaSeparatedList(Delimiter::Braces, parseElt, 340 " in attribute dictionary"); 341 } 342 343 /// Parse a float attribute. 344 Attribute Parser::parseFloatAttr(Type type, bool isNegative) { 345 auto val = getToken().getFloatingPointValue(); 346 if (!val) 347 return (emitError("floating point value too large for attribute"), nullptr); 348 consumeToken(Token::floatliteral); 349 if (!type) { 350 // Default to F64 when no type is specified. 351 if (!consumeIf(Token::colon)) 352 type = builder.getF64Type(); 353 else if (!(type = parseType())) 354 return nullptr; 355 } 356 if (!isa<FloatType>(type)) 357 return (emitError("floating point value not valid for specified type"), 358 nullptr); 359 return FloatAttr::get(type, isNegative ? -*val : *val); 360 } 361 362 /// Construct an APint from a parsed value, a known attribute type and 363 /// sign. 364 static std::optional<APInt> buildAttributeAPInt(Type type, bool isNegative, 365 StringRef spelling) { 366 // Parse the integer value into an APInt that is big enough to hold the value. 367 APInt result; 368 bool isHex = spelling.size() > 1 && spelling[1] == 'x'; 369 if (spelling.getAsInteger(isHex ? 0 : 10, result)) 370 return std::nullopt; 371 372 // Extend or truncate the bitwidth to the right size. 373 unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth 374 : type.getIntOrFloatBitWidth(); 375 376 if (width > result.getBitWidth()) { 377 result = result.zext(width); 378 } else if (width < result.getBitWidth()) { 379 // The parser can return an unnecessarily wide result with leading zeros. 380 // This isn't a problem, but truncating off bits is bad. 381 if (result.countl_zero() < result.getBitWidth() - width) 382 return std::nullopt; 383 384 result = result.trunc(width); 385 } 386 387 if (width == 0) { 388 // 0 bit integers cannot be negative and manipulation of their sign bit will 389 // assert, so short-cut validation here. 390 if (isNegative) 391 return std::nullopt; 392 } else if (isNegative) { 393 // The value is negative, we have an overflow if the sign bit is not set 394 // in the negated apInt. 395 result.negate(); 396 if (!result.isSignBitSet()) 397 return std::nullopt; 398 } else if ((type.isSignedInteger() || type.isIndex()) && 399 result.isSignBitSet()) { 400 // The value is a positive signed integer or index, 401 // we have an overflow if the sign bit is set. 402 return std::nullopt; 403 } 404 405 return result; 406 } 407 408 /// Parse a decimal or a hexadecimal literal, which can be either an integer 409 /// or a float attribute. 410 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { 411 Token tok = getToken(); 412 StringRef spelling = tok.getSpelling(); 413 SMLoc loc = tok.getLoc(); 414 415 consumeToken(Token::integer); 416 if (!type) { 417 // Default to i64 if not type is specified. 418 if (!consumeIf(Token::colon)) 419 type = builder.getIntegerType(64); 420 else if (!(type = parseType())) 421 return nullptr; 422 } 423 424 if (auto floatType = dyn_cast<FloatType>(type)) { 425 std::optional<APFloat> result; 426 if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative, 427 floatType.getFloatSemantics()))) 428 return Attribute(); 429 return FloatAttr::get(floatType, *result); 430 } 431 432 if (!isa<IntegerType, IndexType>(type)) 433 return emitError(loc, "integer literal not valid for specified type"), 434 nullptr; 435 436 if (isNegative && type.isUnsignedInteger()) { 437 emitError(loc, 438 "negative integer literal not valid for unsigned integer type"); 439 return nullptr; 440 } 441 442 std::optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling); 443 if (!apInt) 444 return emitError(loc, "integer constant out of range for attribute"), 445 nullptr; 446 return builder.getIntegerAttr(type, *apInt); 447 } 448 449 //===----------------------------------------------------------------------===// 450 // TensorLiteralParser 451 //===----------------------------------------------------------------------===// 452 453 /// Parse elements values stored within a hex string. On success, the values are 454 /// stored into 'result'. 455 static ParseResult parseElementAttrHexValues(Parser &parser, Token tok, 456 std::string &result) { 457 if (std::optional<std::string> value = tok.getHexStringValue()) { 458 result = std::move(*value); 459 return success(); 460 } 461 return parser.emitError( 462 tok.getLoc(), "expected string containing hex digits starting with `0x`"); 463 } 464 465 namespace { 466 /// This class implements a parser for TensorLiterals. A tensor literal is 467 /// either a single element (e.g, 5) or a multi-dimensional list of elements 468 /// (e.g., [[5, 5]]). 469 class TensorLiteralParser { 470 public: 471 TensorLiteralParser(Parser &p) : p(p) {} 472 473 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser 474 /// may also parse a tensor literal that is store as a hex string. 475 ParseResult parse(bool allowHex); 476 477 /// Build a dense attribute instance with the parsed elements and the given 478 /// shaped type. 479 DenseElementsAttr getAttr(SMLoc loc, ShapedType type); 480 481 ArrayRef<int64_t> getShape() const { return shape; } 482 483 private: 484 /// Get the parsed elements for an integer attribute. 485 ParseResult getIntAttrElements(SMLoc loc, Type eltTy, 486 std::vector<APInt> &intValues); 487 488 /// Get the parsed elements for a float attribute. 489 ParseResult getFloatAttrElements(SMLoc loc, FloatType eltTy, 490 std::vector<APFloat> &floatValues); 491 492 /// Build a Dense String attribute for the given type. 493 DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy); 494 495 /// Build a Dense attribute with hex data for the given type. 496 DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type); 497 498 /// Parse a single element, returning failure if it isn't a valid element 499 /// literal. For example: 500 /// parseElement(1) -> Success, 1 501 /// parseElement([1]) -> Failure 502 ParseResult parseElement(); 503 504 /// Parse a list of either lists or elements, returning the dimensions of the 505 /// parsed sub-tensors in dims. For example: 506 /// parseList([1, 2, 3]) -> Success, [3] 507 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 508 /// parseList([[1, 2], 3]) -> Failure 509 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 510 ParseResult parseList(SmallVectorImpl<int64_t> &dims); 511 512 /// Parse a literal that was printed as a hex string. 513 ParseResult parseHexElements(); 514 515 Parser &p; 516 517 /// The shape inferred from the parsed elements. 518 SmallVector<int64_t, 4> shape; 519 520 /// Storage used when parsing elements, this is a pair of <is_negated, token>. 521 std::vector<std::pair<bool, Token>> storage; 522 523 /// Storage used when parsing elements that were stored as hex values. 524 std::optional<Token> hexStorage; 525 }; 526 } // namespace 527 528 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser 529 /// may also parse a tensor literal that is store as a hex string. 530 ParseResult TensorLiteralParser::parse(bool allowHex) { 531 // If hex is allowed, check for a string literal. 532 if (allowHex && p.getToken().is(Token::string)) { 533 hexStorage = p.getToken(); 534 p.consumeToken(Token::string); 535 return success(); 536 } 537 // Otherwise, parse a list or an individual element. 538 if (p.getToken().is(Token::l_square)) 539 return parseList(shape); 540 return parseElement(); 541 } 542 543 /// Build a dense attribute instance with the parsed elements and the given 544 /// shaped type. 545 DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) { 546 Type eltType = type.getElementType(); 547 548 // Check to see if we parse the literal from a hex string. 549 if (hexStorage && 550 (eltType.isIntOrIndexOrFloat() || isa<ComplexType>(eltType))) 551 return getHexAttr(loc, type); 552 553 // Check that the parsed storage size has the same number of elements to the 554 // type, or is a known splat. 555 if (!shape.empty() && getShape() != type.getShape()) { 556 p.emitError(loc) << "inferred shape of elements literal ([" << getShape() 557 << "]) does not match type ([" << type.getShape() << "])"; 558 return nullptr; 559 } 560 561 // Handle the case where no elements were parsed. 562 if (!hexStorage && storage.empty() && type.getNumElements()) { 563 p.emitError(loc) << "parsed zero elements, but type (" << type 564 << ") expected at least 1"; 565 return nullptr; 566 } 567 568 // Handle complex types in the specific element type cases below. 569 bool isComplex = false; 570 if (ComplexType complexTy = dyn_cast<ComplexType>(eltType)) { 571 eltType = complexTy.getElementType(); 572 isComplex = true; 573 } 574 575 // Handle integer and index types. 576 if (eltType.isIntOrIndex()) { 577 std::vector<APInt> intValues; 578 if (failed(getIntAttrElements(loc, eltType, intValues))) 579 return nullptr; 580 if (isComplex) { 581 // If this is a complex, treat the parsed values as complex values. 582 auto complexData = llvm::ArrayRef( 583 reinterpret_cast<std::complex<APInt> *>(intValues.data()), 584 intValues.size() / 2); 585 return DenseElementsAttr::get(type, complexData); 586 } 587 return DenseElementsAttr::get(type, intValues); 588 } 589 // Handle floating point types. 590 if (FloatType floatTy = dyn_cast<FloatType>(eltType)) { 591 std::vector<APFloat> floatValues; 592 if (failed(getFloatAttrElements(loc, floatTy, floatValues))) 593 return nullptr; 594 if (isComplex) { 595 // If this is a complex, treat the parsed values as complex values. 596 auto complexData = llvm::ArrayRef( 597 reinterpret_cast<std::complex<APFloat> *>(floatValues.data()), 598 floatValues.size() / 2); 599 return DenseElementsAttr::get(type, complexData); 600 } 601 return DenseElementsAttr::get(type, floatValues); 602 } 603 604 // Other types are assumed to be string representations. 605 return getStringAttr(loc, type, type.getElementType()); 606 } 607 608 /// Build a Dense Integer attribute for the given type. 609 ParseResult 610 TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy, 611 std::vector<APInt> &intValues) { 612 intValues.reserve(storage.size()); 613 bool isUintType = eltTy.isUnsignedInteger(); 614 for (const auto &signAndToken : storage) { 615 bool isNegative = signAndToken.first; 616 const Token &token = signAndToken.second; 617 auto tokenLoc = token.getLoc(); 618 619 if (isNegative && isUintType) { 620 return p.emitError(tokenLoc) 621 << "expected unsigned integer elements, but parsed negative value"; 622 } 623 624 // Check to see if floating point values were parsed. 625 if (token.is(Token::floatliteral)) { 626 return p.emitError(tokenLoc) 627 << "expected integer elements, but parsed floating-point"; 628 } 629 630 assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && 631 "unexpected token type"); 632 if (token.isAny(Token::kw_true, Token::kw_false)) { 633 if (!eltTy.isInteger(1)) { 634 return p.emitError(tokenLoc) 635 << "expected i1 type for 'true' or 'false' values"; 636 } 637 APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false); 638 intValues.push_back(apInt); 639 continue; 640 } 641 642 // Create APInt values for each element with the correct bitwidth. 643 std::optional<APInt> apInt = 644 buildAttributeAPInt(eltTy, isNegative, token.getSpelling()); 645 if (!apInt) 646 return p.emitError(tokenLoc, "integer constant out of range for type"); 647 intValues.push_back(*apInt); 648 } 649 return success(); 650 } 651 652 /// Build a Dense Float attribute for the given type. 653 ParseResult 654 TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy, 655 std::vector<APFloat> &floatValues) { 656 floatValues.reserve(storage.size()); 657 for (const auto &signAndToken : storage) { 658 bool isNegative = signAndToken.first; 659 const Token &token = signAndToken.second; 660 std::optional<APFloat> result; 661 if (failed(p.parseFloatFromLiteral(result, token, isNegative, 662 eltTy.getFloatSemantics()))) 663 return failure(); 664 floatValues.push_back(*result); 665 } 666 return success(); 667 } 668 669 /// Build a Dense String attribute for the given type. 670 DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type, 671 Type eltTy) { 672 if (hexStorage.has_value()) { 673 auto stringValue = hexStorage->getStringValue(); 674 return DenseStringElementsAttr::get(type, {stringValue}); 675 } 676 677 std::vector<std::string> stringValues; 678 std::vector<StringRef> stringRefValues; 679 stringValues.reserve(storage.size()); 680 stringRefValues.reserve(storage.size()); 681 682 for (auto val : storage) { 683 stringValues.push_back(val.second.getStringValue()); 684 stringRefValues.emplace_back(stringValues.back()); 685 } 686 687 return DenseStringElementsAttr::get(type, stringRefValues); 688 } 689 690 /// Build a Dense attribute with hex data for the given type. 691 DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) { 692 Type elementType = type.getElementType(); 693 if (!elementType.isIntOrIndexOrFloat() && !isa<ComplexType>(elementType)) { 694 p.emitError(loc) 695 << "expected floating-point, integer, or complex element type, got " 696 << elementType; 697 return nullptr; 698 } 699 700 std::string data; 701 if (parseElementAttrHexValues(p, *hexStorage, data)) 702 return nullptr; 703 704 ArrayRef<char> rawData(data.data(), data.size()); 705 bool detectedSplat = false; 706 if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) { 707 p.emitError(loc) << "elements hex data size is invalid for provided type: " 708 << type; 709 return nullptr; 710 } 711 712 if (llvm::endianness::native == llvm::endianness::big) { 713 // Convert endianess in big-endian(BE) machines. `rawData` is 714 // little-endian(LE) because HEX in raw data of dense element attribute 715 // is always LE format. It is converted into BE here to be used in BE 716 // machines. 717 SmallVector<char, 64> outDataVec(rawData.size()); 718 MutableArrayRef<char> convRawData(outDataVec); 719 DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 720 rawData, convRawData, type); 721 return DenseElementsAttr::getFromRawBuffer(type, convRawData); 722 } 723 724 return DenseElementsAttr::getFromRawBuffer(type, rawData); 725 } 726 727 ParseResult TensorLiteralParser::parseElement() { 728 switch (p.getToken().getKind()) { 729 // Parse a boolean element. 730 case Token::kw_true: 731 case Token::kw_false: 732 case Token::floatliteral: 733 case Token::integer: 734 storage.emplace_back(/*isNegative=*/false, p.getToken()); 735 p.consumeToken(); 736 break; 737 738 // Parse a signed integer or a negative floating-point element. 739 case Token::minus: 740 p.consumeToken(Token::minus); 741 if (!p.getToken().isAny(Token::floatliteral, Token::integer)) 742 return p.emitError("expected integer or floating point literal"); 743 storage.emplace_back(/*isNegative=*/true, p.getToken()); 744 p.consumeToken(); 745 break; 746 747 case Token::string: 748 storage.emplace_back(/*isNegative=*/false, p.getToken()); 749 p.consumeToken(); 750 break; 751 752 // Parse a complex element of the form '(' element ',' element ')'. 753 case Token::l_paren: 754 p.consumeToken(Token::l_paren); 755 if (parseElement() || 756 p.parseToken(Token::comma, "expected ',' between complex elements") || 757 parseElement() || 758 p.parseToken(Token::r_paren, "expected ')' after complex elements")) 759 return failure(); 760 break; 761 762 default: 763 return p.emitError("expected element literal of primitive type"); 764 } 765 766 return success(); 767 } 768 769 /// Parse a list of either lists or elements, returning the dimensions of the 770 /// parsed sub-tensors in dims. For example: 771 /// parseList([1, 2, 3]) -> Success, [3] 772 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 773 /// parseList([[1, 2], 3]) -> Failure 774 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 775 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) { 776 auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims, 777 const SmallVectorImpl<int64_t> &newDims) -> ParseResult { 778 if (prevDims == newDims) 779 return success(); 780 return p.emitError("tensor literal is invalid; ranks are not consistent " 781 "between elements"); 782 }; 783 784 bool first = true; 785 SmallVector<int64_t, 4> newDims; 786 unsigned size = 0; 787 auto parseOneElement = [&]() -> ParseResult { 788 SmallVector<int64_t, 4> thisDims; 789 if (p.getToken().getKind() == Token::l_square) { 790 if (parseList(thisDims)) 791 return failure(); 792 } else if (parseElement()) { 793 return failure(); 794 } 795 ++size; 796 if (!first) 797 return checkDims(newDims, thisDims); 798 newDims = thisDims; 799 first = false; 800 return success(); 801 }; 802 if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseOneElement)) 803 return failure(); 804 805 // Return the sublists' dimensions with 'size' prepended. 806 dims.clear(); 807 dims.push_back(size); 808 dims.append(newDims.begin(), newDims.end()); 809 return success(); 810 } 811 812 //===----------------------------------------------------------------------===// 813 // DenseArrayAttr Parser 814 //===----------------------------------------------------------------------===// 815 816 namespace { 817 /// A generic dense array element parser. It parsers integer and floating point 818 /// elements. 819 class DenseArrayElementParser { 820 public: 821 explicit DenseArrayElementParser(Type type) : type(type) {} 822 823 /// Parse an integer element. 824 ParseResult parseIntegerElement(Parser &p); 825 826 /// Parse a floating point element. 827 ParseResult parseFloatElement(Parser &p); 828 829 /// Convert the current contents to a dense array. 830 DenseArrayAttr getAttr() { return DenseArrayAttr::get(type, size, rawData); } 831 832 private: 833 /// Append the raw data of an APInt to the result. 834 void append(const APInt &data); 835 836 /// The array element type. 837 Type type; 838 /// The resultant byte array representing the contents of the array. 839 std::vector<char> rawData; 840 /// The number of elements in the array. 841 int64_t size = 0; 842 }; 843 } // namespace 844 845 void DenseArrayElementParser::append(const APInt &data) { 846 if (data.getBitWidth()) { 847 assert(data.getBitWidth() % 8 == 0); 848 unsigned byteSize = data.getBitWidth() / 8; 849 size_t offset = rawData.size(); 850 rawData.insert(rawData.end(), byteSize, 0); 851 llvm::StoreIntToMemory( 852 data, reinterpret_cast<uint8_t *>(rawData.data() + offset), byteSize); 853 } 854 ++size; 855 } 856 857 ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) { 858 bool isNegative = p.consumeIf(Token::minus); 859 860 // Parse an integer literal as an APInt. 861 std::optional<APInt> value; 862 StringRef spelling = p.getToken().getSpelling(); 863 if (p.getToken().isAny(Token::kw_true, Token::kw_false)) { 864 if (!type.isInteger(1)) 865 return p.emitError("expected i1 type for 'true' or 'false' values"); 866 value = APInt(/*numBits=*/8, p.getToken().is(Token::kw_true), 867 !type.isUnsignedInteger()); 868 p.consumeToken(); 869 } else if (p.consumeIf(Token::integer)) { 870 value = buildAttributeAPInt(type, isNegative, spelling); 871 if (!value) 872 return p.emitError("integer constant out of range"); 873 } else { 874 return p.emitError("expected integer literal"); 875 } 876 append(*value); 877 return success(); 878 } 879 880 ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) { 881 bool isNegative = p.consumeIf(Token::minus); 882 Token token = p.getToken(); 883 std::optional<APFloat> fromIntLit; 884 if (failed( 885 p.parseFloatFromLiteral(fromIntLit, token, isNegative, 886 cast<FloatType>(type).getFloatSemantics()))) 887 return failure(); 888 p.consumeToken(); 889 append(fromIntLit->bitcastToAPInt()); 890 return success(); 891 } 892 893 /// Parse a dense array attribute. 894 Attribute Parser::parseDenseArrayAttr(Type attrType) { 895 consumeToken(Token::kw_array); 896 if (parseToken(Token::less, "expected '<' after 'array'")) 897 return {}; 898 899 SMLoc typeLoc = getToken().getLoc(); 900 Type eltType = parseType(); 901 if (!eltType) { 902 emitError(typeLoc, "expected an integer or floating point type"); 903 return {}; 904 } 905 906 // Only bool or integer and floating point elements divisible by bytes are 907 // supported. 908 if (!eltType.isIntOrIndexOrFloat()) { 909 emitError(typeLoc, "expected integer or float type, got: ") << eltType; 910 return {}; 911 } 912 if (!eltType.isInteger(1) && eltType.getIntOrFloatBitWidth() % 8 != 0) { 913 emitError(typeLoc, "element type bitwidth must be a multiple of 8"); 914 return {}; 915 } 916 917 // Check for empty list. 918 if (consumeIf(Token::greater)) 919 return DenseArrayAttr::get(eltType, 0, {}); 920 921 if (parseToken(Token::colon, "expected ':' after dense array type")) 922 return {}; 923 924 DenseArrayElementParser eltParser(eltType); 925 if (eltType.isIntOrIndex()) { 926 if (parseCommaSeparatedList( 927 [&] { return eltParser.parseIntegerElement(*this); })) 928 return {}; 929 } else { 930 if (parseCommaSeparatedList( 931 [&] { return eltParser.parseFloatElement(*this); })) 932 return {}; 933 } 934 if (parseToken(Token::greater, "expected '>' to close an array attribute")) 935 return {}; 936 return eltParser.getAttr(); 937 } 938 939 /// Parse a dense elements attribute. 940 Attribute Parser::parseDenseElementsAttr(Type attrType) { 941 auto attribLoc = getToken().getLoc(); 942 consumeToken(Token::kw_dense); 943 if (parseToken(Token::less, "expected '<' after 'dense'")) 944 return nullptr; 945 946 // Parse the literal data if necessary. 947 TensorLiteralParser literalParser(*this); 948 if (!consumeIf(Token::greater)) { 949 if (literalParser.parse(/*allowHex=*/true) || 950 parseToken(Token::greater, "expected '>'")) 951 return nullptr; 952 } 953 954 // If the type is specified `parseElementsLiteralType` will not parse a type. 955 // Use the attribute location as the location for error reporting in that 956 // case. 957 auto loc = attrType ? attribLoc : getToken().getLoc(); 958 auto type = parseElementsLiteralType(attrType); 959 if (!type) 960 return nullptr; 961 return literalParser.getAttr(loc, type); 962 } 963 964 Attribute Parser::parseDenseResourceElementsAttr(Type attrType) { 965 auto loc = getToken().getLoc(); 966 consumeToken(Token::kw_dense_resource); 967 if (parseToken(Token::less, "expected '<' after 'dense_resource'")) 968 return nullptr; 969 970 // Parse the resource handle. 971 FailureOr<AsmDialectResourceHandle> rawHandle = 972 parseResourceHandle(getContext()->getLoadedDialect<BuiltinDialect>()); 973 if (failed(rawHandle) || parseToken(Token::greater, "expected '>'")) 974 return nullptr; 975 976 auto *handle = dyn_cast<DenseResourceElementsHandle>(&*rawHandle); 977 if (!handle) 978 return emitError(loc, "invalid `dense_resource` handle type"), nullptr; 979 980 // Parse the type of the attribute if the user didn't provide one. 981 SMLoc typeLoc = loc; 982 if (!attrType) { 983 typeLoc = getToken().getLoc(); 984 if (parseToken(Token::colon, "expected ':'") || !(attrType = parseType())) 985 return nullptr; 986 } 987 988 ShapedType shapedType = dyn_cast<ShapedType>(attrType); 989 if (!shapedType) { 990 emitError(typeLoc, "`dense_resource` expected a shaped type"); 991 return nullptr; 992 } 993 994 return DenseResourceElementsAttr::get(shapedType, *handle); 995 } 996 997 /// Shaped type for elements attribute. 998 /// 999 /// elements-literal-type ::= vector-type | ranked-tensor-type 1000 /// 1001 /// This method also checks the type has static shape. 1002 ShapedType Parser::parseElementsLiteralType(Type type) { 1003 // If the user didn't provide a type, parse the colon type for the literal. 1004 if (!type) { 1005 if (parseToken(Token::colon, "expected ':'")) 1006 return nullptr; 1007 if (!(type = parseType())) 1008 return nullptr; 1009 } 1010 1011 auto sType = dyn_cast<ShapedType>(type); 1012 if (!sType) { 1013 emitError("elements literal must be a shaped type"); 1014 return nullptr; 1015 } 1016 1017 if (!sType.hasStaticShape()) 1018 return (emitError("elements literal type must have static shape"), nullptr); 1019 1020 return sType; 1021 } 1022 1023 /// Parse a sparse elements attribute. 1024 Attribute Parser::parseSparseElementsAttr(Type attrType) { 1025 SMLoc loc = getToken().getLoc(); 1026 consumeToken(Token::kw_sparse); 1027 if (parseToken(Token::less, "Expected '<' after 'sparse'")) 1028 return nullptr; 1029 1030 // Check for the case where all elements are sparse. The indices are 1031 // represented by a 2-dimensional shape where the second dimension is the rank 1032 // of the type. 1033 Type indiceEltType = builder.getIntegerType(64); 1034 if (consumeIf(Token::greater)) { 1035 ShapedType type = parseElementsLiteralType(attrType); 1036 if (!type) 1037 return nullptr; 1038 1039 // Construct the sparse elements attr using zero element indice/value 1040 // attributes. 1041 ShapedType indicesType = 1042 RankedTensorType::get({0, type.getRank()}, indiceEltType); 1043 ShapedType valuesType = RankedTensorType::get({0}, type.getElementType()); 1044 return getChecked<SparseElementsAttr>( 1045 loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()), 1046 DenseElementsAttr::get(valuesType, ArrayRef<Attribute>())); 1047 } 1048 1049 /// Parse the indices. We don't allow hex values here as we may need to use 1050 /// the inferred shape. 1051 auto indicesLoc = getToken().getLoc(); 1052 TensorLiteralParser indiceParser(*this); 1053 if (indiceParser.parse(/*allowHex=*/false)) 1054 return nullptr; 1055 1056 if (parseToken(Token::comma, "expected ','")) 1057 return nullptr; 1058 1059 /// Parse the values. 1060 auto valuesLoc = getToken().getLoc(); 1061 TensorLiteralParser valuesParser(*this); 1062 if (valuesParser.parse(/*allowHex=*/true)) 1063 return nullptr; 1064 1065 if (parseToken(Token::greater, "expected '>'")) 1066 return nullptr; 1067 1068 auto type = parseElementsLiteralType(attrType); 1069 if (!type) 1070 return nullptr; 1071 1072 // If the indices are a splat, i.e. the literal parser parsed an element and 1073 // not a list, we set the shape explicitly. The indices are represented by a 1074 // 2-dimensional shape where the second dimension is the rank of the type. 1075 // Given that the parsed indices is a splat, we know that we only have one 1076 // indice and thus one for the first dimension. 1077 ShapedType indicesType; 1078 if (indiceParser.getShape().empty()) { 1079 indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); 1080 } else { 1081 // Otherwise, set the shape to the one parsed by the literal parser. 1082 indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); 1083 } 1084 auto indices = indiceParser.getAttr(indicesLoc, indicesType); 1085 1086 // If the values are a splat, set the shape explicitly based on the number of 1087 // indices. The number of indices is encoded in the first dimension of the 1088 // indice shape type. 1089 auto valuesEltType = type.getElementType(); 1090 ShapedType valuesType = 1091 valuesParser.getShape().empty() 1092 ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) 1093 : RankedTensorType::get(valuesParser.getShape(), valuesEltType); 1094 auto values = valuesParser.getAttr(valuesLoc, valuesType); 1095 1096 // Build the sparse elements attribute by the indices and values. 1097 return getChecked<SparseElementsAttr>(loc, type, indices, values); 1098 } 1099 1100 Attribute Parser::parseStridedLayoutAttr() { 1101 // Callback for error emissing at the keyword token location. 1102 llvm::SMLoc loc = getToken().getLoc(); 1103 auto errorEmitter = [&] { return emitError(loc); }; 1104 1105 consumeToken(Token::kw_strided); 1106 if (failed(parseToken(Token::less, "expected '<' after 'strided'")) || 1107 failed(parseToken(Token::l_square, "expected '['"))) 1108 return nullptr; 1109 1110 // Parses either an integer token or a question mark token. Reports an error 1111 // and returns std::nullopt if the current token is neither. The integer token 1112 // must fit into int64_t limits. 1113 auto parseStrideOrOffset = [&]() -> std::optional<int64_t> { 1114 if (consumeIf(Token::question)) 1115 return ShapedType::kDynamic; 1116 1117 SMLoc loc = getToken().getLoc(); 1118 auto emitWrongTokenError = [&] { 1119 emitError(loc, "expected a 64-bit signed integer or '?'"); 1120 return std::nullopt; 1121 }; 1122 1123 bool negative = consumeIf(Token::minus); 1124 1125 if (getToken().is(Token::integer)) { 1126 std::optional<uint64_t> value = getToken().getUInt64IntegerValue(); 1127 if (!value || 1128 *value > static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) 1129 return emitWrongTokenError(); 1130 consumeToken(); 1131 auto result = static_cast<int64_t>(*value); 1132 if (negative) 1133 result = -result; 1134 1135 return result; 1136 } 1137 1138 return emitWrongTokenError(); 1139 }; 1140 1141 // Parse strides. 1142 SmallVector<int64_t> strides; 1143 if (!getToken().is(Token::r_square)) { 1144 do { 1145 std::optional<int64_t> stride = parseStrideOrOffset(); 1146 if (!stride) 1147 return nullptr; 1148 strides.push_back(*stride); 1149 } while (consumeIf(Token::comma)); 1150 } 1151 1152 if (failed(parseToken(Token::r_square, "expected ']'"))) 1153 return nullptr; 1154 1155 // Fast path in absence of offset. 1156 if (consumeIf(Token::greater)) { 1157 if (failed(StridedLayoutAttr::verify(errorEmitter, 1158 /*offset=*/0, strides))) 1159 return nullptr; 1160 return StridedLayoutAttr::get(getContext(), /*offset=*/0, strides); 1161 } 1162 1163 if (failed(parseToken(Token::comma, "expected ','")) || 1164 failed(parseToken(Token::kw_offset, "expected 'offset' after comma")) || 1165 failed(parseToken(Token::colon, "expected ':' after 'offset'"))) 1166 return nullptr; 1167 1168 std::optional<int64_t> offset = parseStrideOrOffset(); 1169 if (!offset || failed(parseToken(Token::greater, "expected '>'"))) 1170 return nullptr; 1171 1172 if (failed(StridedLayoutAttr::verify(errorEmitter, *offset, strides))) 1173 return nullptr; 1174 return StridedLayoutAttr::get(getContext(), *offset, strides); 1175 // return getChecked<StridedLayoutAttr>(loc,getContext(), *offset, strides); 1176 } 1177 1178 /// Parse a distinct attribute. 1179 /// 1180 /// distinct-attribute ::= `distinct` 1181 /// `[` integer-literal `]<` attribute-value `>` 1182 /// 1183 Attribute Parser::parseDistinctAttr(Type type) { 1184 SMLoc loc = getToken().getLoc(); 1185 consumeToken(Token::kw_distinct); 1186 if (parseToken(Token::l_square, "expected '[' after 'distinct'")) 1187 return {}; 1188 1189 // Parse the distinct integer identifier. 1190 Token token = getToken(); 1191 if (parseToken(Token::integer, "expected distinct ID")) 1192 return {}; 1193 std::optional<uint64_t> value = token.getUInt64IntegerValue(); 1194 if (!value) { 1195 emitError("expected an unsigned 64-bit integer"); 1196 return {}; 1197 } 1198 1199 // Parse the referenced attribute. 1200 if (parseToken(Token::r_square, "expected ']' to close distinct ID") || 1201 parseToken(Token::less, "expected '<' after distinct ID")) 1202 return {}; 1203 1204 Attribute referencedAttr; 1205 if (getToken().is(Token::greater)) { 1206 consumeToken(); 1207 referencedAttr = builder.getUnitAttr(); 1208 } else { 1209 referencedAttr = parseAttribute(type); 1210 if (!referencedAttr) { 1211 emitError("expected attribute"); 1212 return {}; 1213 } 1214 1215 if (parseToken(Token::greater, "expected '>' to close distinct attribute")) 1216 return {}; 1217 } 1218 1219 // Add the distinct attribute to the parser state, if it has not been parsed 1220 // before. Otherwise, check if the parsed reference attribute matches the one 1221 // found in the parser state. 1222 DenseMap<uint64_t, DistinctAttr> &distinctAttrs = 1223 state.symbols.distinctAttributes; 1224 auto it = distinctAttrs.find(*value); 1225 if (it == distinctAttrs.end()) { 1226 DistinctAttr distinctAttr = DistinctAttr::create(referencedAttr); 1227 it = distinctAttrs.try_emplace(*value, distinctAttr).first; 1228 } else if (it->getSecond().getReferencedAttr() != referencedAttr) { 1229 emitError(loc, "referenced attribute does not match previous definition: ") 1230 << it->getSecond().getReferencedAttr(); 1231 return {}; 1232 } 1233 1234 return it->getSecond(); 1235 } 1236