//===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the parser for the MLIR Attributes.
//
//===----------------------------------------------------------------------===//

#include "Parser.h"

#include "AsmParserImpl.h"
#include "mlir/AsmParser/AsmParserState.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Endian.h"
#include <optional>

using namespace mlir;
using namespace mlir::detail;

/// Parse an arbitrary attribute.
///
///  attribute-value ::= `unit`
///                    | bool-literal
///                    | integer-literal (`:` (index-type | integer-type))?
36c60b897dSRiver Riddle /// | float-literal (`:` float-type)? 37c60b897dSRiver Riddle /// | string-literal (`:` type)? 38c60b897dSRiver Riddle /// | type 39c60b897dSRiver Riddle /// | `[` `:` (integer-type | float-type) tensor-literal `]` 40c60b897dSRiver Riddle /// | `[` (attribute-value (`,` attribute-value)*)? `]` 41c60b897dSRiver Riddle /// | `{` (attribute-entry (`,` attribute-entry)*)? `}` 42c60b897dSRiver Riddle /// | symbol-ref-id (`::` symbol-ref-id)* 43c60b897dSRiver Riddle /// | `dense` `<` tensor-literal `>` `:` 44c60b897dSRiver Riddle /// (tensor-type | vector-type) 45c60b897dSRiver Riddle /// | `sparse` `<` attribute-value `,` attribute-value `>` 46c60b897dSRiver Riddle /// `:` (tensor-type | vector-type) 47519847feSAlex Zinenko /// | `strided` `<` `[` comma-separated-int-or-question `]` 48519847feSAlex Zinenko /// (`,` `offset` `:` integer-literal)? `>` 49728a8d5aSTobias Gysi /// | distinct-attribute 50c60b897dSRiver Riddle /// | extended-attribute 51c60b897dSRiver Riddle /// 52c60b897dSRiver Riddle Attribute Parser::parseAttribute(Type type) { 53c60b897dSRiver Riddle switch (getToken().getKind()) { 54c60b897dSRiver Riddle // Parse an AffineMap or IntegerSet attribute. 
55c60b897dSRiver Riddle case Token::kw_affine_map: { 56c60b897dSRiver Riddle consumeToken(Token::kw_affine_map); 57c60b897dSRiver Riddle 58c60b897dSRiver Riddle AffineMap map; 59c60b897dSRiver Riddle if (parseToken(Token::less, "expected '<' in affine map") || 60c60b897dSRiver Riddle parseAffineMapReference(map) || 61c60b897dSRiver Riddle parseToken(Token::greater, "expected '>' in affine map")) 62c60b897dSRiver Riddle return Attribute(); 63c60b897dSRiver Riddle return AffineMapAttr::get(map); 64c60b897dSRiver Riddle } 65c60b897dSRiver Riddle case Token::kw_affine_set: { 66c60b897dSRiver Riddle consumeToken(Token::kw_affine_set); 67c60b897dSRiver Riddle 68c60b897dSRiver Riddle IntegerSet set; 69c60b897dSRiver Riddle if (parseToken(Token::less, "expected '<' in integer set") || 70c60b897dSRiver Riddle parseIntegerSetReference(set) || 71c60b897dSRiver Riddle parseToken(Token::greater, "expected '>' in integer set")) 72c60b897dSRiver Riddle return Attribute(); 73c60b897dSRiver Riddle return IntegerSetAttr::get(set); 74c60b897dSRiver Riddle } 75c60b897dSRiver Riddle 76c60b897dSRiver Riddle // Parse an array attribute. 77c60b897dSRiver Riddle case Token::l_square: { 78c60b897dSRiver Riddle consumeToken(Token::l_square); 79c60b897dSRiver Riddle SmallVector<Attribute, 4> elements; 80c60b897dSRiver Riddle auto parseElt = [&]() -> ParseResult { 81c60b897dSRiver Riddle elements.push_back(parseAttribute()); 82c60b897dSRiver Riddle return elements.back() ? success() : failure(); 83c60b897dSRiver Riddle }; 84c60b897dSRiver Riddle 85c60b897dSRiver Riddle if (parseCommaSeparatedListUntil(Token::r_square, parseElt)) 86c60b897dSRiver Riddle return nullptr; 87c60b897dSRiver Riddle return builder.getArrayAttr(elements); 88c60b897dSRiver Riddle } 89c60b897dSRiver Riddle 90c60b897dSRiver Riddle // Parse a boolean attribute. 
91c60b897dSRiver Riddle case Token::kw_false: 92c60b897dSRiver Riddle consumeToken(Token::kw_false); 93c60b897dSRiver Riddle return builder.getBoolAttr(false); 94c60b897dSRiver Riddle case Token::kw_true: 95c60b897dSRiver Riddle consumeToken(Token::kw_true); 96c60b897dSRiver Riddle return builder.getBoolAttr(true); 97c60b897dSRiver Riddle 98c60b897dSRiver Riddle // Parse a dense elements attribute. 99c60b897dSRiver Riddle case Token::kw_dense: 100c60b897dSRiver Riddle return parseDenseElementsAttr(type); 101c60b897dSRiver Riddle 102995ab929SRiver Riddle // Parse a dense resource elements attribute. 103995ab929SRiver Riddle case Token::kw_dense_resource: 104995ab929SRiver Riddle return parseDenseResourceElementsAttr(type); 105995ab929SRiver Riddle 1062092d143SJeff Niu // Parse a dense array attribute. 1072092d143SJeff Niu case Token::kw_array: 1082092d143SJeff Niu return parseDenseArrayAttr(type); 1092092d143SJeff Niu 110c60b897dSRiver Riddle // Parse a dictionary attribute. 111c60b897dSRiver Riddle case Token::l_brace: { 112c60b897dSRiver Riddle NamedAttrList elements; 113c60b897dSRiver Riddle if (parseAttributeDict(elements)) 114c60b897dSRiver Riddle return nullptr; 115c60b897dSRiver Riddle return elements.getDictionary(getContext()); 116c60b897dSRiver Riddle } 117c60b897dSRiver Riddle 118c60b897dSRiver Riddle // Parse an extended attribute, i.e. alias or dialect attribute. 119c60b897dSRiver Riddle case Token::hash_identifier: 120c60b897dSRiver Riddle return parseExtendedAttr(type); 121c60b897dSRiver Riddle 122c60b897dSRiver Riddle // Parse floating point and integer attributes. 
123c60b897dSRiver Riddle case Token::floatliteral: 124c60b897dSRiver Riddle return parseFloatAttr(type, /*isNegative=*/false); 125c60b897dSRiver Riddle case Token::integer: 126c60b897dSRiver Riddle return parseDecOrHexAttr(type, /*isNegative=*/false); 127c60b897dSRiver Riddle case Token::minus: { 128c60b897dSRiver Riddle consumeToken(Token::minus); 129c60b897dSRiver Riddle if (getToken().is(Token::integer)) 130c60b897dSRiver Riddle return parseDecOrHexAttr(type, /*isNegative=*/true); 131c60b897dSRiver Riddle if (getToken().is(Token::floatliteral)) 132c60b897dSRiver Riddle return parseFloatAttr(type, /*isNegative=*/true); 133c60b897dSRiver Riddle 134c60b897dSRiver Riddle return (emitWrongTokenError( 135c60b897dSRiver Riddle "expected constant integer or floating point value"), 136c60b897dSRiver Riddle nullptr); 137c60b897dSRiver Riddle } 138c60b897dSRiver Riddle 139c60b897dSRiver Riddle // Parse a location attribute. 140c60b897dSRiver Riddle case Token::kw_loc: { 141c60b897dSRiver Riddle consumeToken(Token::kw_loc); 142c60b897dSRiver Riddle 143c60b897dSRiver Riddle LocationAttr locAttr; 144c60b897dSRiver Riddle if (parseToken(Token::l_paren, "expected '(' in inline location") || 145c60b897dSRiver Riddle parseLocationInstance(locAttr) || 146c60b897dSRiver Riddle parseToken(Token::r_paren, "expected ')' in inline location")) 147c60b897dSRiver Riddle return Attribute(); 148c60b897dSRiver Riddle return locAttr; 149c60b897dSRiver Riddle } 150c60b897dSRiver Riddle 151c60b897dSRiver Riddle // Parse a sparse elements attribute. 152c60b897dSRiver Riddle case Token::kw_sparse: 153c60b897dSRiver Riddle return parseSparseElementsAttr(type); 154c60b897dSRiver Riddle 155519847feSAlex Zinenko // Parse a strided layout attribute. 156519847feSAlex Zinenko case Token::kw_strided: 157519847feSAlex Zinenko return parseStridedLayoutAttr(); 158519847feSAlex Zinenko 159728a8d5aSTobias Gysi // Parse a distinct attribute. 
160728a8d5aSTobias Gysi case Token::kw_distinct: 161728a8d5aSTobias Gysi return parseDistinctAttr(type); 162728a8d5aSTobias Gysi 163c60b897dSRiver Riddle // Parse a string attribute. 164c60b897dSRiver Riddle case Token::string: { 165c60b897dSRiver Riddle auto val = getToken().getStringValue(); 166c60b897dSRiver Riddle consumeToken(Token::string); 167c60b897dSRiver Riddle // Parse the optional trailing colon type if one wasn't explicitly provided. 168c60b897dSRiver Riddle if (!type && consumeIf(Token::colon) && !(type = parseType())) 169c60b897dSRiver Riddle return Attribute(); 170c60b897dSRiver Riddle 171c60b897dSRiver Riddle return type ? StringAttr::get(val, type) 172c60b897dSRiver Riddle : StringAttr::get(getContext(), val); 173c60b897dSRiver Riddle } 174c60b897dSRiver Riddle 175c60b897dSRiver Riddle // Parse a symbol reference attribute. 176c60b897dSRiver Riddle case Token::at_identifier: { 177c60b897dSRiver Riddle // When populating the parser state, this is a list of locations for all of 178c60b897dSRiver Riddle // the nested references. 179c60b897dSRiver Riddle SmallVector<SMRange> referenceLocations; 180c60b897dSRiver Riddle if (state.asmState) 181c60b897dSRiver Riddle referenceLocations.push_back(getToken().getLocRange()); 182c60b897dSRiver Riddle 183c60b897dSRiver Riddle // Parse the top-level reference. 184c60b897dSRiver Riddle std::string nameStr = getToken().getSymbolReference(); 185c60b897dSRiver Riddle consumeToken(Token::at_identifier); 186c60b897dSRiver Riddle 187c60b897dSRiver Riddle // Parse any nested references. 188c60b897dSRiver Riddle std::vector<FlatSymbolRefAttr> nestedRefs; 189c60b897dSRiver Riddle while (getToken().is(Token::colon)) { 190c60b897dSRiver Riddle // Check for the '::' prefix. 
191c60b897dSRiver Riddle const char *curPointer = getToken().getLoc().getPointer(); 192c60b897dSRiver Riddle consumeToken(Token::colon); 193c60b897dSRiver Riddle if (!consumeIf(Token::colon)) { 194c60b897dSRiver Riddle if (getToken().isNot(Token::eof, Token::error)) { 195c60b897dSRiver Riddle state.lex.resetPointer(curPointer); 196c60b897dSRiver Riddle consumeToken(); 197c60b897dSRiver Riddle } 198c60b897dSRiver Riddle break; 199c60b897dSRiver Riddle } 200c60b897dSRiver Riddle // Parse the reference itself. 201c60b897dSRiver Riddle auto curLoc = getToken().getLoc(); 202c60b897dSRiver Riddle if (getToken().isNot(Token::at_identifier)) { 203c60b897dSRiver Riddle emitError(curLoc, "expected nested symbol reference identifier"); 204c60b897dSRiver Riddle return Attribute(); 205c60b897dSRiver Riddle } 206c60b897dSRiver Riddle 207c60b897dSRiver Riddle // If we are populating the assembly state, add the location for this 208c60b897dSRiver Riddle // reference. 209c60b897dSRiver Riddle if (state.asmState) 210c60b897dSRiver Riddle referenceLocations.push_back(getToken().getLocRange()); 211c60b897dSRiver Riddle 212c60b897dSRiver Riddle std::string nameStr = getToken().getSymbolReference(); 213c60b897dSRiver Riddle consumeToken(Token::at_identifier); 214c60b897dSRiver Riddle nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr)); 215c60b897dSRiver Riddle } 216c60b897dSRiver Riddle SymbolRefAttr symbolRefAttr = 217c60b897dSRiver Riddle SymbolRefAttr::get(getContext(), nameStr, nestedRefs); 218c60b897dSRiver Riddle 219c60b897dSRiver Riddle // If we are populating the assembly state, record this symbol reference. 220c60b897dSRiver Riddle if (state.asmState) 221c60b897dSRiver Riddle state.asmState->addUses(symbolRefAttr, referenceLocations); 222c60b897dSRiver Riddle return symbolRefAttr; 223c60b897dSRiver Riddle } 224c60b897dSRiver Riddle 225c60b897dSRiver Riddle // Parse a 'unit' attribute. 
226c60b897dSRiver Riddle case Token::kw_unit: 227c60b897dSRiver Riddle consumeToken(Token::kw_unit); 228c60b897dSRiver Riddle return builder.getUnitAttr(); 229c60b897dSRiver Riddle 230c60b897dSRiver Riddle // Handle completion of an attribute. 231c60b897dSRiver Riddle case Token::code_complete: 232c60b897dSRiver Riddle if (getToken().isCodeCompletionFor(Token::hash_identifier)) 233c60b897dSRiver Riddle return parseExtendedAttr(type); 234c60b897dSRiver Riddle return codeCompleteAttribute(); 235c60b897dSRiver Riddle 236c60b897dSRiver Riddle default: 237c60b897dSRiver Riddle // Parse a type attribute. We parse `Optional` here to allow for providing a 238c60b897dSRiver Riddle // better error message. 239c60b897dSRiver Riddle Type type; 240c60b897dSRiver Riddle OptionalParseResult result = parseOptionalType(type); 2419750648cSKazu Hirata if (!result.has_value()) 242c60b897dSRiver Riddle return emitWrongTokenError("expected attribute value"), Attribute(); 243c60b897dSRiver Riddle return failed(*result) ? Attribute() : TypeAttr::get(type); 244c60b897dSRiver Riddle } 245c60b897dSRiver Riddle } 246c60b897dSRiver Riddle 247c60b897dSRiver Riddle /// Parse an optional attribute with the provided type. 
248c60b897dSRiver Riddle OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute, 249c60b897dSRiver Riddle Type type) { 250c60b897dSRiver Riddle switch (getToken().getKind()) { 251c60b897dSRiver Riddle case Token::at_identifier: 252c60b897dSRiver Riddle case Token::floatliteral: 253c60b897dSRiver Riddle case Token::integer: 254c60b897dSRiver Riddle case Token::hash_identifier: 255c60b897dSRiver Riddle case Token::kw_affine_map: 256c60b897dSRiver Riddle case Token::kw_affine_set: 257c60b897dSRiver Riddle case Token::kw_dense: 258995ab929SRiver Riddle case Token::kw_dense_resource: 259c60b897dSRiver Riddle case Token::kw_false: 260c60b897dSRiver Riddle case Token::kw_loc: 261c60b897dSRiver Riddle case Token::kw_sparse: 262c60b897dSRiver Riddle case Token::kw_true: 263c60b897dSRiver Riddle case Token::kw_unit: 264c60b897dSRiver Riddle case Token::l_brace: 265c60b897dSRiver Riddle case Token::l_square: 266c60b897dSRiver Riddle case Token::minus: 267c60b897dSRiver Riddle case Token::string: 268c60b897dSRiver Riddle attribute = parseAttribute(type); 269c60b897dSRiver Riddle return success(attribute != nullptr); 270c60b897dSRiver Riddle 271c60b897dSRiver Riddle default: 272c60b897dSRiver Riddle // Parse an optional type attribute. 
273c60b897dSRiver Riddle Type type; 274c60b897dSRiver Riddle OptionalParseResult result = parseOptionalType(type); 2759750648cSKazu Hirata if (result.has_value() && succeeded(*result)) 276c60b897dSRiver Riddle attribute = TypeAttr::get(type); 277c60b897dSRiver Riddle return result; 278c60b897dSRiver Riddle } 279c60b897dSRiver Riddle } 280c60b897dSRiver Riddle OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute, 281c60b897dSRiver Riddle Type type) { 282c60b897dSRiver Riddle return parseOptionalAttributeWithToken(Token::l_square, attribute, type); 283c60b897dSRiver Riddle } 284c60b897dSRiver Riddle OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute, 285c60b897dSRiver Riddle Type type) { 286c60b897dSRiver Riddle return parseOptionalAttributeWithToken(Token::string, attribute, type); 287c60b897dSRiver Riddle } 2888bb8421bSRiver Riddle OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result, 2898bb8421bSRiver Riddle Type type) { 2908bb8421bSRiver Riddle return parseOptionalAttributeWithToken(Token::at_identifier, result, type); 2918bb8421bSRiver Riddle } 292c60b897dSRiver Riddle 293c60b897dSRiver Riddle /// Attribute dictionary. 294c60b897dSRiver Riddle /// 295c60b897dSRiver Riddle /// attribute-dict ::= `{` `}` 296c60b897dSRiver Riddle /// | `{` attribute-entry (`,` attribute-entry)* `}` 297c60b897dSRiver Riddle /// attribute-entry ::= (bare-id | string-literal) `=` attribute-value 298c60b897dSRiver Riddle /// 299c60b897dSRiver Riddle ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) { 300c60b897dSRiver Riddle llvm::SmallDenseSet<StringAttr> seenKeys; 301c60b897dSRiver Riddle auto parseElt = [&]() -> ParseResult { 302c60b897dSRiver Riddle // The name of an attribute can either be a bare identifier, or a string. 
303d8be2081SKazu Hirata std::optional<StringAttr> nameId; 304c60b897dSRiver Riddle if (getToken().is(Token::string)) 305c60b897dSRiver Riddle nameId = builder.getStringAttr(getToken().getStringValue()); 306c60b897dSRiver Riddle else if (getToken().isAny(Token::bare_identifier, Token::inttype) || 307c60b897dSRiver Riddle getToken().isKeyword()) 308c60b897dSRiver Riddle nameId = builder.getStringAttr(getTokenSpelling()); 309c60b897dSRiver Riddle else 310c60b897dSRiver Riddle return emitWrongTokenError("expected attribute name"); 311c60b897dSRiver Riddle 312398e48a7SAdrian Kuegel if (nameId->empty()) 313c60b897dSRiver Riddle return emitError("expected valid attribute name"); 314c60b897dSRiver Riddle 315c60b897dSRiver Riddle if (!seenKeys.insert(*nameId).second) 316c60b897dSRiver Riddle return emitError("duplicate key '") 317c60b897dSRiver Riddle << nameId->getValue() << "' in dictionary attribute"; 318c60b897dSRiver Riddle consumeToken(); 319c60b897dSRiver Riddle 320c60b897dSRiver Riddle // Lazy load a dialect in the context if there is a possible namespace. 321c60b897dSRiver Riddle auto splitName = nameId->strref().split('.'); 322c60b897dSRiver Riddle if (!splitName.second.empty()) 323c60b897dSRiver Riddle getContext()->getOrLoadDialect(splitName.first); 324c60b897dSRiver Riddle 325c60b897dSRiver Riddle // Try to parse the '=' for the attribute value. 326c60b897dSRiver Riddle if (!consumeIf(Token::equal)) { 327c60b897dSRiver Riddle // If there is no '=', we treat this as a unit attribute. 
328c60b897dSRiver Riddle attributes.push_back({*nameId, builder.getUnitAttr()}); 329c60b897dSRiver Riddle return success(); 330c60b897dSRiver Riddle } 331c60b897dSRiver Riddle 332c60b897dSRiver Riddle auto attr = parseAttribute(); 333c60b897dSRiver Riddle if (!attr) 334c60b897dSRiver Riddle return failure(); 335c60b897dSRiver Riddle attributes.push_back({*nameId, attr}); 336c60b897dSRiver Riddle return success(); 337c60b897dSRiver Riddle }; 338c60b897dSRiver Riddle 339c60b897dSRiver Riddle return parseCommaSeparatedList(Delimiter::Braces, parseElt, 340c60b897dSRiver Riddle " in attribute dictionary"); 341c60b897dSRiver Riddle } 342c60b897dSRiver Riddle 343c60b897dSRiver Riddle /// Parse a float attribute. 344c60b897dSRiver Riddle Attribute Parser::parseFloatAttr(Type type, bool isNegative) { 345c60b897dSRiver Riddle auto val = getToken().getFloatingPointValue(); 346c60b897dSRiver Riddle if (!val) 347c60b897dSRiver Riddle return (emitError("floating point value too large for attribute"), nullptr); 348c60b897dSRiver Riddle consumeToken(Token::floatliteral); 349c60b897dSRiver Riddle if (!type) { 350c60b897dSRiver Riddle // Default to F64 when no type is specified. 351c60b897dSRiver Riddle if (!consumeIf(Token::colon)) 352c60b897dSRiver Riddle type = builder.getF64Type(); 353c60b897dSRiver Riddle else if (!(type = parseType())) 354c60b897dSRiver Riddle return nullptr; 355c60b897dSRiver Riddle } 3565550c821STres Popp if (!isa<FloatType>(type)) 357c60b897dSRiver Riddle return (emitError("floating point value not valid for specified type"), 358c60b897dSRiver Riddle nullptr); 359c60b897dSRiver Riddle return FloatAttr::get(type, isNegative ? -*val : *val); 360c60b897dSRiver Riddle } 361c60b897dSRiver Riddle 362c60b897dSRiver Riddle /// Construct an APint from a parsed value, a known attribute type and 363c60b897dSRiver Riddle /// sign. 
3640a81ace0SKazu Hirata static std::optional<APInt> buildAttributeAPInt(Type type, bool isNegative, 365c60b897dSRiver Riddle StringRef spelling) { 366c60b897dSRiver Riddle // Parse the integer value into an APInt that is big enough to hold the value. 367c60b897dSRiver Riddle APInt result; 368c60b897dSRiver Riddle bool isHex = spelling.size() > 1 && spelling[1] == 'x'; 369c60b897dSRiver Riddle if (spelling.getAsInteger(isHex ? 0 : 10, result)) 3701a36588eSKazu Hirata return std::nullopt; 371c60b897dSRiver Riddle 372c60b897dSRiver Riddle // Extend or truncate the bitwidth to the right size. 373c60b897dSRiver Riddle unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth 374c60b897dSRiver Riddle : type.getIntOrFloatBitWidth(); 375c60b897dSRiver Riddle 376c60b897dSRiver Riddle if (width > result.getBitWidth()) { 377c60b897dSRiver Riddle result = result.zext(width); 378c60b897dSRiver Riddle } else if (width < result.getBitWidth()) { 379c60b897dSRiver Riddle // The parser can return an unnecessarily wide result with leading zeros. 380c60b897dSRiver Riddle // This isn't a problem, but truncating off bits is bad. 381f8f3db27SKazu Hirata if (result.countl_zero() < result.getBitWidth() - width) 3821a36588eSKazu Hirata return std::nullopt; 383c60b897dSRiver Riddle 384c60b897dSRiver Riddle result = result.trunc(width); 385c60b897dSRiver Riddle } 386c60b897dSRiver Riddle 387c60b897dSRiver Riddle if (width == 0) { 388c60b897dSRiver Riddle // 0 bit integers cannot be negative and manipulation of their sign bit will 389c60b897dSRiver Riddle // assert, so short-cut validation here. 390c60b897dSRiver Riddle if (isNegative) 3911a36588eSKazu Hirata return std::nullopt; 392c60b897dSRiver Riddle } else if (isNegative) { 393c60b897dSRiver Riddle // The value is negative, we have an overflow if the sign bit is not set 394c60b897dSRiver Riddle // in the negated apInt. 
395c60b897dSRiver Riddle result.negate(); 396c60b897dSRiver Riddle if (!result.isSignBitSet()) 3971a36588eSKazu Hirata return std::nullopt; 398c60b897dSRiver Riddle } else if ((type.isSignedInteger() || type.isIndex()) && 399c60b897dSRiver Riddle result.isSignBitSet()) { 400c60b897dSRiver Riddle // The value is a positive signed integer or index, 401c60b897dSRiver Riddle // we have an overflow if the sign bit is set. 4021a36588eSKazu Hirata return std::nullopt; 403c60b897dSRiver Riddle } 404c60b897dSRiver Riddle 405c60b897dSRiver Riddle return result; 406c60b897dSRiver Riddle } 407c60b897dSRiver Riddle 408c60b897dSRiver Riddle /// Parse a decimal or a hexadecimal literal, which can be either an integer 409c60b897dSRiver Riddle /// or a float attribute. 410c60b897dSRiver Riddle Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { 411c60b897dSRiver Riddle Token tok = getToken(); 412c60b897dSRiver Riddle StringRef spelling = tok.getSpelling(); 413c60b897dSRiver Riddle SMLoc loc = tok.getLoc(); 414c60b897dSRiver Riddle 415c60b897dSRiver Riddle consumeToken(Token::integer); 416c60b897dSRiver Riddle if (!type) { 417c60b897dSRiver Riddle // Default to i64 if not type is specified. 
418c60b897dSRiver Riddle if (!consumeIf(Token::colon)) 419c60b897dSRiver Riddle type = builder.getIntegerType(64); 420c60b897dSRiver Riddle else if (!(type = parseType())) 421c60b897dSRiver Riddle return nullptr; 422c60b897dSRiver Riddle } 423c60b897dSRiver Riddle 4245550c821STres Popp if (auto floatType = dyn_cast<FloatType>(type)) { 4250a81ace0SKazu Hirata std::optional<APFloat> result; 426c60b897dSRiver Riddle if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative, 427*4548bff0SMatthias Springer floatType.getFloatSemantics()))) 428c60b897dSRiver Riddle return Attribute(); 429c60b897dSRiver Riddle return FloatAttr::get(floatType, *result); 430c60b897dSRiver Riddle } 431c60b897dSRiver Riddle 4325550c821STres Popp if (!isa<IntegerType, IndexType>(type)) 433c60b897dSRiver Riddle return emitError(loc, "integer literal not valid for specified type"), 434c60b897dSRiver Riddle nullptr; 435c60b897dSRiver Riddle 436c60b897dSRiver Riddle if (isNegative && type.isUnsignedInteger()) { 437c60b897dSRiver Riddle emitError(loc, 438c60b897dSRiver Riddle "negative integer literal not valid for unsigned integer type"); 439c60b897dSRiver Riddle return nullptr; 440c60b897dSRiver Riddle } 441c60b897dSRiver Riddle 4420a81ace0SKazu Hirata std::optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling); 443c60b897dSRiver Riddle if (!apInt) 444c60b897dSRiver Riddle return emitError(loc, "integer constant out of range for attribute"), 445c60b897dSRiver Riddle nullptr; 446c60b897dSRiver Riddle return builder.getIntegerAttr(type, *apInt); 447c60b897dSRiver Riddle } 448c60b897dSRiver Riddle 449c60b897dSRiver Riddle //===----------------------------------------------------------------------===// 450c60b897dSRiver Riddle // TensorLiteralParser 451c60b897dSRiver Riddle //===----------------------------------------------------------------------===// 452c60b897dSRiver Riddle 453c60b897dSRiver Riddle /// Parse elements values stored within a hex string. 
On success, the values are 454c60b897dSRiver Riddle /// stored into 'result'. 455c60b897dSRiver Riddle static ParseResult parseElementAttrHexValues(Parser &parser, Token tok, 456c60b897dSRiver Riddle std::string &result) { 4570a81ace0SKazu Hirata if (std::optional<std::string> value = tok.getHexStringValue()) { 458c60b897dSRiver Riddle result = std::move(*value); 459c60b897dSRiver Riddle return success(); 460c60b897dSRiver Riddle } 461c60b897dSRiver Riddle return parser.emitError( 462c60b897dSRiver Riddle tok.getLoc(), "expected string containing hex digits starting with `0x`"); 463c60b897dSRiver Riddle } 464c60b897dSRiver Riddle 465c60b897dSRiver Riddle namespace { 466c60b897dSRiver Riddle /// This class implements a parser for TensorLiterals. A tensor literal is 467c60b897dSRiver Riddle /// either a single element (e.g, 5) or a multi-dimensional list of elements 468c60b897dSRiver Riddle /// (e.g., [[5, 5]]). 469c60b897dSRiver Riddle class TensorLiteralParser { 470c60b897dSRiver Riddle public: 471c60b897dSRiver Riddle TensorLiteralParser(Parser &p) : p(p) {} 472c60b897dSRiver Riddle 473c60b897dSRiver Riddle /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser 474c60b897dSRiver Riddle /// may also parse a tensor literal that is store as a hex string. 475c60b897dSRiver Riddle ParseResult parse(bool allowHex); 476c60b897dSRiver Riddle 477c60b897dSRiver Riddle /// Build a dense attribute instance with the parsed elements and the given 478c60b897dSRiver Riddle /// shaped type. 479c60b897dSRiver Riddle DenseElementsAttr getAttr(SMLoc loc, ShapedType type); 480c60b897dSRiver Riddle 481c60b897dSRiver Riddle ArrayRef<int64_t> getShape() const { return shape; } 482c60b897dSRiver Riddle 483c60b897dSRiver Riddle private: 484c60b897dSRiver Riddle /// Get the parsed elements for an integer attribute. 
485c60b897dSRiver Riddle ParseResult getIntAttrElements(SMLoc loc, Type eltTy, 486c60b897dSRiver Riddle std::vector<APInt> &intValues); 487c60b897dSRiver Riddle 488c60b897dSRiver Riddle /// Get the parsed elements for a float attribute. 489c60b897dSRiver Riddle ParseResult getFloatAttrElements(SMLoc loc, FloatType eltTy, 490c60b897dSRiver Riddle std::vector<APFloat> &floatValues); 491c60b897dSRiver Riddle 492c60b897dSRiver Riddle /// Build a Dense String attribute for the given type. 493c60b897dSRiver Riddle DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy); 494c60b897dSRiver Riddle 495c60b897dSRiver Riddle /// Build a Dense attribute with hex data for the given type. 496c60b897dSRiver Riddle DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type); 497c60b897dSRiver Riddle 498c60b897dSRiver Riddle /// Parse a single element, returning failure if it isn't a valid element 499c60b897dSRiver Riddle /// literal. For example: 500c60b897dSRiver Riddle /// parseElement(1) -> Success, 1 501c60b897dSRiver Riddle /// parseElement([1]) -> Failure 502c60b897dSRiver Riddle ParseResult parseElement(); 503c60b897dSRiver Riddle 504c60b897dSRiver Riddle /// Parse a list of either lists or elements, returning the dimensions of the 505c60b897dSRiver Riddle /// parsed sub-tensors in dims. For example: 506c60b897dSRiver Riddle /// parseList([1, 2, 3]) -> Success, [3] 507c60b897dSRiver Riddle /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 508c60b897dSRiver Riddle /// parseList([[1, 2], 3]) -> Failure 509c60b897dSRiver Riddle /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 510c60b897dSRiver Riddle ParseResult parseList(SmallVectorImpl<int64_t> &dims); 511c60b897dSRiver Riddle 512c60b897dSRiver Riddle /// Parse a literal that was printed as a hex string. 
513c60b897dSRiver Riddle ParseResult parseHexElements(); 514c60b897dSRiver Riddle 515c60b897dSRiver Riddle Parser &p; 516c60b897dSRiver Riddle 517c60b897dSRiver Riddle /// The shape inferred from the parsed elements. 518c60b897dSRiver Riddle SmallVector<int64_t, 4> shape; 519c60b897dSRiver Riddle 520c60b897dSRiver Riddle /// Storage used when parsing elements, this is a pair of <is_negated, token>. 521c60b897dSRiver Riddle std::vector<std::pair<bool, Token>> storage; 522c60b897dSRiver Riddle 523c60b897dSRiver Riddle /// Storage used when parsing elements that were stored as hex values. 524d8be2081SKazu Hirata std::optional<Token> hexStorage; 525c60b897dSRiver Riddle }; 526c60b897dSRiver Riddle } // namespace 527c60b897dSRiver Riddle 528c60b897dSRiver Riddle /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser 529c60b897dSRiver Riddle /// may also parse a tensor literal that is store as a hex string. 530c60b897dSRiver Riddle ParseResult TensorLiteralParser::parse(bool allowHex) { 531c60b897dSRiver Riddle // If hex is allowed, check for a string literal. 532c60b897dSRiver Riddle if (allowHex && p.getToken().is(Token::string)) { 533c60b897dSRiver Riddle hexStorage = p.getToken(); 534c60b897dSRiver Riddle p.consumeToken(Token::string); 535c60b897dSRiver Riddle return success(); 536c60b897dSRiver Riddle } 537c60b897dSRiver Riddle // Otherwise, parse a list or an individual element. 538c60b897dSRiver Riddle if (p.getToken().is(Token::l_square)) 539c60b897dSRiver Riddle return parseList(shape); 540c60b897dSRiver Riddle return parseElement(); 541c60b897dSRiver Riddle } 542c60b897dSRiver Riddle 543c60b897dSRiver Riddle /// Build a dense attribute instance with the parsed elements and the given 544c60b897dSRiver Riddle /// shaped type. 
545c60b897dSRiver Riddle DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) { 546c60b897dSRiver Riddle Type eltType = type.getElementType(); 547c60b897dSRiver Riddle 548c60b897dSRiver Riddle // Check to see if we parse the literal from a hex string. 549c60b897dSRiver Riddle if (hexStorage && 5505550c821STres Popp (eltType.isIntOrIndexOrFloat() || isa<ComplexType>(eltType))) 551c60b897dSRiver Riddle return getHexAttr(loc, type); 552c60b897dSRiver Riddle 553c60b897dSRiver Riddle // Check that the parsed storage size has the same number of elements to the 554c60b897dSRiver Riddle // type, or is a known splat. 555c60b897dSRiver Riddle if (!shape.empty() && getShape() != type.getShape()) { 556c60b897dSRiver Riddle p.emitError(loc) << "inferred shape of elements literal ([" << getShape() 557c60b897dSRiver Riddle << "]) does not match type ([" << type.getShape() << "])"; 558c60b897dSRiver Riddle return nullptr; 559c60b897dSRiver Riddle } 560c60b897dSRiver Riddle 561c60b897dSRiver Riddle // Handle the case where no elements were parsed. 562c60b897dSRiver Riddle if (!hexStorage && storage.empty() && type.getNumElements()) { 563c60b897dSRiver Riddle p.emitError(loc) << "parsed zero elements, but type (" << type 564c60b897dSRiver Riddle << ") expected at least 1"; 565c60b897dSRiver Riddle return nullptr; 566c60b897dSRiver Riddle } 567c60b897dSRiver Riddle 568c60b897dSRiver Riddle // Handle complex types in the specific element type cases below. 569c60b897dSRiver Riddle bool isComplex = false; 5705550c821STres Popp if (ComplexType complexTy = dyn_cast<ComplexType>(eltType)) { 571c60b897dSRiver Riddle eltType = complexTy.getElementType(); 572c60b897dSRiver Riddle isComplex = true; 573c60b897dSRiver Riddle } 574c60b897dSRiver Riddle 575c60b897dSRiver Riddle // Handle integer and index types. 
576c60b897dSRiver Riddle if (eltType.isIntOrIndex()) { 577c60b897dSRiver Riddle std::vector<APInt> intValues; 578c60b897dSRiver Riddle if (failed(getIntAttrElements(loc, eltType, intValues))) 579c60b897dSRiver Riddle return nullptr; 580c60b897dSRiver Riddle if (isComplex) { 581c60b897dSRiver Riddle // If this is a complex, treat the parsed values as complex values. 582984b800aSserge-sans-paille auto complexData = llvm::ArrayRef( 583c60b897dSRiver Riddle reinterpret_cast<std::complex<APInt> *>(intValues.data()), 584c60b897dSRiver Riddle intValues.size() / 2); 585c60b897dSRiver Riddle return DenseElementsAttr::get(type, complexData); 586c60b897dSRiver Riddle } 587c60b897dSRiver Riddle return DenseElementsAttr::get(type, intValues); 588c60b897dSRiver Riddle } 589c60b897dSRiver Riddle // Handle floating point types. 5905550c821STres Popp if (FloatType floatTy = dyn_cast<FloatType>(eltType)) { 591c60b897dSRiver Riddle std::vector<APFloat> floatValues; 592c60b897dSRiver Riddle if (failed(getFloatAttrElements(loc, floatTy, floatValues))) 593c60b897dSRiver Riddle return nullptr; 594c60b897dSRiver Riddle if (isComplex) { 595c60b897dSRiver Riddle // If this is a complex, treat the parsed values as complex values. 596984b800aSserge-sans-paille auto complexData = llvm::ArrayRef( 597c60b897dSRiver Riddle reinterpret_cast<std::complex<APFloat> *>(floatValues.data()), 598c60b897dSRiver Riddle floatValues.size() / 2); 599c60b897dSRiver Riddle return DenseElementsAttr::get(type, complexData); 600c60b897dSRiver Riddle } 601c60b897dSRiver Riddle return DenseElementsAttr::get(type, floatValues); 602c60b897dSRiver Riddle } 603c60b897dSRiver Riddle 604c60b897dSRiver Riddle // Other types are assumed to be string representations. 605c60b897dSRiver Riddle return getStringAttr(loc, type, type.getElementType()); 606c60b897dSRiver Riddle } 607c60b897dSRiver Riddle 608c60b897dSRiver Riddle /// Build a Dense Integer attribute for the given type. 
609c60b897dSRiver Riddle ParseResult 610c60b897dSRiver Riddle TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy, 611c60b897dSRiver Riddle std::vector<APInt> &intValues) { 612c60b897dSRiver Riddle intValues.reserve(storage.size()); 613c60b897dSRiver Riddle bool isUintType = eltTy.isUnsignedInteger(); 614c60b897dSRiver Riddle for (const auto &signAndToken : storage) { 615c60b897dSRiver Riddle bool isNegative = signAndToken.first; 616c60b897dSRiver Riddle const Token &token = signAndToken.second; 617c60b897dSRiver Riddle auto tokenLoc = token.getLoc(); 618c60b897dSRiver Riddle 619c60b897dSRiver Riddle if (isNegative && isUintType) { 620c60b897dSRiver Riddle return p.emitError(tokenLoc) 621c60b897dSRiver Riddle << "expected unsigned integer elements, but parsed negative value"; 622c60b897dSRiver Riddle } 623c60b897dSRiver Riddle 624c60b897dSRiver Riddle // Check to see if floating point values were parsed. 625c60b897dSRiver Riddle if (token.is(Token::floatliteral)) { 626c60b897dSRiver Riddle return p.emitError(tokenLoc) 627c60b897dSRiver Riddle << "expected integer elements, but parsed floating-point"; 628c60b897dSRiver Riddle } 629c60b897dSRiver Riddle 630c60b897dSRiver Riddle assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && 631c60b897dSRiver Riddle "unexpected token type"); 632c60b897dSRiver Riddle if (token.isAny(Token::kw_true, Token::kw_false)) { 633c60b897dSRiver Riddle if (!eltTy.isInteger(1)) { 634c60b897dSRiver Riddle return p.emitError(tokenLoc) 635c60b897dSRiver Riddle << "expected i1 type for 'true' or 'false' values"; 636c60b897dSRiver Riddle } 637c60b897dSRiver Riddle APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false); 638c60b897dSRiver Riddle intValues.push_back(apInt); 639c60b897dSRiver Riddle continue; 640c60b897dSRiver Riddle } 641c60b897dSRiver Riddle 642c60b897dSRiver Riddle // Create APInt values for each element with the correct bitwidth. 
6430a81ace0SKazu Hirata std::optional<APInt> apInt = 644c60b897dSRiver Riddle buildAttributeAPInt(eltTy, isNegative, token.getSpelling()); 645c60b897dSRiver Riddle if (!apInt) 646c60b897dSRiver Riddle return p.emitError(tokenLoc, "integer constant out of range for type"); 647c60b897dSRiver Riddle intValues.push_back(*apInt); 648c60b897dSRiver Riddle } 649c60b897dSRiver Riddle return success(); 650c60b897dSRiver Riddle } 651c60b897dSRiver Riddle 652c60b897dSRiver Riddle /// Build a Dense Float attribute for the given type. 653c60b897dSRiver Riddle ParseResult 654c60b897dSRiver Riddle TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy, 655c60b897dSRiver Riddle std::vector<APFloat> &floatValues) { 656c60b897dSRiver Riddle floatValues.reserve(storage.size()); 657c60b897dSRiver Riddle for (const auto &signAndToken : storage) { 658c60b897dSRiver Riddle bool isNegative = signAndToken.first; 659c60b897dSRiver Riddle const Token &token = signAndToken.second; 6600a81ace0SKazu Hirata std::optional<APFloat> result; 661*4548bff0SMatthias Springer if (failed(p.parseFloatFromLiteral(result, token, isNegative, 662*4548bff0SMatthias Springer eltTy.getFloatSemantics()))) 663c60b897dSRiver Riddle return failure(); 664c60b897dSRiver Riddle floatValues.push_back(*result); 665c60b897dSRiver Riddle } 666c60b897dSRiver Riddle return success(); 667c60b897dSRiver Riddle } 668c60b897dSRiver Riddle 669c60b897dSRiver Riddle /// Build a Dense String attribute for the given type. 
670c60b897dSRiver Riddle DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type, 671c60b897dSRiver Riddle Type eltTy) { 672c60b897dSRiver Riddle if (hexStorage.has_value()) { 6734913e5daSFangrui Song auto stringValue = hexStorage->getStringValue(); 674c60b897dSRiver Riddle return DenseStringElementsAttr::get(type, {stringValue}); 675c60b897dSRiver Riddle } 676c60b897dSRiver Riddle 677c60b897dSRiver Riddle std::vector<std::string> stringValues; 678c60b897dSRiver Riddle std::vector<StringRef> stringRefValues; 679c60b897dSRiver Riddle stringValues.reserve(storage.size()); 680c60b897dSRiver Riddle stringRefValues.reserve(storage.size()); 681c60b897dSRiver Riddle 682c60b897dSRiver Riddle for (auto val : storage) { 683c60b897dSRiver Riddle stringValues.push_back(val.second.getStringValue()); 684c60b897dSRiver Riddle stringRefValues.emplace_back(stringValues.back()); 685c60b897dSRiver Riddle } 686c60b897dSRiver Riddle 687c60b897dSRiver Riddle return DenseStringElementsAttr::get(type, stringRefValues); 688c60b897dSRiver Riddle } 689c60b897dSRiver Riddle 690c60b897dSRiver Riddle /// Build a Dense attribute with hex data for the given type. 
691c60b897dSRiver Riddle DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) { 692c60b897dSRiver Riddle Type elementType = type.getElementType(); 6935550c821STres Popp if (!elementType.isIntOrIndexOrFloat() && !isa<ComplexType>(elementType)) { 694c60b897dSRiver Riddle p.emitError(loc) 695c60b897dSRiver Riddle << "expected floating-point, integer, or complex element type, got " 696c60b897dSRiver Riddle << elementType; 697c60b897dSRiver Riddle return nullptr; 698c60b897dSRiver Riddle } 699c60b897dSRiver Riddle 700c60b897dSRiver Riddle std::string data; 701c60b897dSRiver Riddle if (parseElementAttrHexValues(p, *hexStorage, data)) 702c60b897dSRiver Riddle return nullptr; 703c60b897dSRiver Riddle 704c60b897dSRiver Riddle ArrayRef<char> rawData(data.data(), data.size()); 705c60b897dSRiver Riddle bool detectedSplat = false; 706c60b897dSRiver Riddle if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) { 707c60b897dSRiver Riddle p.emitError(loc) << "elements hex data size is invalid for provided type: " 708c60b897dSRiver Riddle << type; 709c60b897dSRiver Riddle return nullptr; 710c60b897dSRiver Riddle } 711c60b897dSRiver Riddle 7126b31b026SKazu Hirata if (llvm::endianness::native == llvm::endianness::big) { 713c60b897dSRiver Riddle // Convert endianess in big-endian(BE) machines. `rawData` is 714c60b897dSRiver Riddle // little-endian(LE) because HEX in raw data of dense element attribute 715c60b897dSRiver Riddle // is always LE format. It is converted into BE here to be used in BE 716c60b897dSRiver Riddle // machines. 
717c60b897dSRiver Riddle SmallVector<char, 64> outDataVec(rawData.size()); 718c60b897dSRiver Riddle MutableArrayRef<char> convRawData(outDataVec); 719c60b897dSRiver Riddle DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 720c60b897dSRiver Riddle rawData, convRawData, type); 721c60b897dSRiver Riddle return DenseElementsAttr::getFromRawBuffer(type, convRawData); 722c60b897dSRiver Riddle } 723c60b897dSRiver Riddle 724c60b897dSRiver Riddle return DenseElementsAttr::getFromRawBuffer(type, rawData); 725c60b897dSRiver Riddle } 726c60b897dSRiver Riddle 727c60b897dSRiver Riddle ParseResult TensorLiteralParser::parseElement() { 728c60b897dSRiver Riddle switch (p.getToken().getKind()) { 729c60b897dSRiver Riddle // Parse a boolean element. 730c60b897dSRiver Riddle case Token::kw_true: 731c60b897dSRiver Riddle case Token::kw_false: 732c60b897dSRiver Riddle case Token::floatliteral: 733c60b897dSRiver Riddle case Token::integer: 734c60b897dSRiver Riddle storage.emplace_back(/*isNegative=*/false, p.getToken()); 735c60b897dSRiver Riddle p.consumeToken(); 736c60b897dSRiver Riddle break; 737c60b897dSRiver Riddle 738c60b897dSRiver Riddle // Parse a signed integer or a negative floating-point element. 739c60b897dSRiver Riddle case Token::minus: 740c60b897dSRiver Riddle p.consumeToken(Token::minus); 741c60b897dSRiver Riddle if (!p.getToken().isAny(Token::floatliteral, Token::integer)) 742c60b897dSRiver Riddle return p.emitError("expected integer or floating point literal"); 743c60b897dSRiver Riddle storage.emplace_back(/*isNegative=*/true, p.getToken()); 744c60b897dSRiver Riddle p.consumeToken(); 745c60b897dSRiver Riddle break; 746c60b897dSRiver Riddle 747c60b897dSRiver Riddle case Token::string: 748c60b897dSRiver Riddle storage.emplace_back(/*isNegative=*/false, p.getToken()); 749c60b897dSRiver Riddle p.consumeToken(); 750c60b897dSRiver Riddle break; 751c60b897dSRiver Riddle 752c60b897dSRiver Riddle // Parse a complex element of the form '(' element ',' element ')'. 
753c60b897dSRiver Riddle case Token::l_paren: 754c60b897dSRiver Riddle p.consumeToken(Token::l_paren); 755c60b897dSRiver Riddle if (parseElement() || 756c60b897dSRiver Riddle p.parseToken(Token::comma, "expected ',' between complex elements") || 757c60b897dSRiver Riddle parseElement() || 758c60b897dSRiver Riddle p.parseToken(Token::r_paren, "expected ')' after complex elements")) 759c60b897dSRiver Riddle return failure(); 760c60b897dSRiver Riddle break; 761c60b897dSRiver Riddle 762c60b897dSRiver Riddle default: 763c60b897dSRiver Riddle return p.emitError("expected element literal of primitive type"); 764c60b897dSRiver Riddle } 765c60b897dSRiver Riddle 766c60b897dSRiver Riddle return success(); 767c60b897dSRiver Riddle } 768c60b897dSRiver Riddle 769c60b897dSRiver Riddle /// Parse a list of either lists or elements, returning the dimensions of the 770c60b897dSRiver Riddle /// parsed sub-tensors in dims. For example: 771c60b897dSRiver Riddle /// parseList([1, 2, 3]) -> Success, [3] 772c60b897dSRiver Riddle /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] 773c60b897dSRiver Riddle /// parseList([[1, 2], 3]) -> Failure 774c60b897dSRiver Riddle /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure 775c60b897dSRiver Riddle ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) { 776c60b897dSRiver Riddle auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims, 777c60b897dSRiver Riddle const SmallVectorImpl<int64_t> &newDims) -> ParseResult { 778c60b897dSRiver Riddle if (prevDims == newDims) 779c60b897dSRiver Riddle return success(); 780c60b897dSRiver Riddle return p.emitError("tensor literal is invalid; ranks are not consistent " 781c60b897dSRiver Riddle "between elements"); 782c60b897dSRiver Riddle }; 783c60b897dSRiver Riddle 784c60b897dSRiver Riddle bool first = true; 785c60b897dSRiver Riddle SmallVector<int64_t, 4> newDims; 786c60b897dSRiver Riddle unsigned size = 0; 787c60b897dSRiver Riddle auto parseOneElement = [&]() -> ParseResult { 
788c60b897dSRiver Riddle SmallVector<int64_t, 4> thisDims; 789c60b897dSRiver Riddle if (p.getToken().getKind() == Token::l_square) { 790c60b897dSRiver Riddle if (parseList(thisDims)) 791c60b897dSRiver Riddle return failure(); 792c60b897dSRiver Riddle } else if (parseElement()) { 793c60b897dSRiver Riddle return failure(); 794c60b897dSRiver Riddle } 795c60b897dSRiver Riddle ++size; 796c60b897dSRiver Riddle if (!first) 797c60b897dSRiver Riddle return checkDims(newDims, thisDims); 798c60b897dSRiver Riddle newDims = thisDims; 799c60b897dSRiver Riddle first = false; 800c60b897dSRiver Riddle return success(); 801c60b897dSRiver Riddle }; 802c60b897dSRiver Riddle if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseOneElement)) 803c60b897dSRiver Riddle return failure(); 804c60b897dSRiver Riddle 805c60b897dSRiver Riddle // Return the sublists' dimensions with 'size' prepended. 806c60b897dSRiver Riddle dims.clear(); 807c60b897dSRiver Riddle dims.push_back(size); 808c60b897dSRiver Riddle dims.append(newDims.begin(), newDims.end()); 809c60b897dSRiver Riddle return success(); 810c60b897dSRiver Riddle } 811c60b897dSRiver Riddle 812c60b897dSRiver Riddle //===----------------------------------------------------------------------===// 813cec7e80eSJeff Niu // DenseArrayAttr Parser 814c60b897dSRiver Riddle //===----------------------------------------------------------------------===// 815c60b897dSRiver Riddle 816c60b897dSRiver Riddle namespace { 817cec7e80eSJeff Niu /// A generic dense array element parser. It parsers integer and floating point 818cec7e80eSJeff Niu /// elements. 819cec7e80eSJeff Niu class DenseArrayElementParser { 820c60b897dSRiver Riddle public: 821cec7e80eSJeff Niu explicit DenseArrayElementParser(Type type) : type(type) {} 822cec7e80eSJeff Niu 823cec7e80eSJeff Niu /// Parse an integer element. 824cec7e80eSJeff Niu ParseResult parseIntegerElement(Parser &p); 825cec7e80eSJeff Niu 826cec7e80eSJeff Niu /// Parse a floating point element. 
827cec7e80eSJeff Niu ParseResult parseFloatElement(Parser &p); 828cec7e80eSJeff Niu 829cec7e80eSJeff Niu /// Convert the current contents to a dense array. 830c48e0cf0SJeff Niu DenseArrayAttr getAttr() { return DenseArrayAttr::get(type, size, rawData); } 831cec7e80eSJeff Niu 832cec7e80eSJeff Niu private: 833cec7e80eSJeff Niu /// Append the raw data of an APInt to the result. 834cec7e80eSJeff Niu void append(const APInt &data); 835cec7e80eSJeff Niu 836cec7e80eSJeff Niu /// The array element type. 837cec7e80eSJeff Niu Type type; 838cec7e80eSJeff Niu /// The resultant byte array representing the contents of the array. 839cec7e80eSJeff Niu std::vector<char> rawData; 840cec7e80eSJeff Niu /// The number of elements in the array. 841cec7e80eSJeff Niu int64_t size = 0; 842c60b897dSRiver Riddle }; 843c60b897dSRiver Riddle } // namespace 844c60b897dSRiver Riddle 845cec7e80eSJeff Niu void DenseArrayElementParser::append(const APInt &data) { 846cbf81558SVitaly Buka if (data.getBitWidth()) { 847cbf81558SVitaly Buka assert(data.getBitWidth() % 8 == 0); 848cec7e80eSJeff Niu unsigned byteSize = data.getBitWidth() / 8; 849cec7e80eSJeff Niu size_t offset = rawData.size(); 850cec7e80eSJeff Niu rawData.insert(rawData.end(), byteSize, 0); 851cec7e80eSJeff Niu llvm::StoreIntToMemory( 852cec7e80eSJeff Niu data, reinterpret_cast<uint8_t *>(rawData.data() + offset), byteSize); 853cbf81558SVitaly Buka } 854cec7e80eSJeff Niu ++size; 855cec7e80eSJeff Niu } 856cec7e80eSJeff Niu 857cec7e80eSJeff Niu ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) { 858cec7e80eSJeff Niu bool isNegative = p.consumeIf(Token::minus); 859cec7e80eSJeff Niu 860cec7e80eSJeff Niu // Parse an integer literal as an APInt. 
8610a81ace0SKazu Hirata std::optional<APInt> value; 862cec7e80eSJeff Niu StringRef spelling = p.getToken().getSpelling(); 863cec7e80eSJeff Niu if (p.getToken().isAny(Token::kw_true, Token::kw_false)) { 864cec7e80eSJeff Niu if (!type.isInteger(1)) 865cec7e80eSJeff Niu return p.emitError("expected i1 type for 'true' or 'false' values"); 866cec7e80eSJeff Niu value = APInt(/*numBits=*/8, p.getToken().is(Token::kw_true), 867cec7e80eSJeff Niu !type.isUnsignedInteger()); 868cec7e80eSJeff Niu p.consumeToken(); 869cec7e80eSJeff Niu } else if (p.consumeIf(Token::integer)) { 870cec7e80eSJeff Niu value = buildAttributeAPInt(type, isNegative, spelling); 871cec7e80eSJeff Niu if (!value) 872cec7e80eSJeff Niu return p.emitError("integer constant out of range"); 873cec7e80eSJeff Niu } else { 874cec7e80eSJeff Niu return p.emitError("expected integer literal"); 875cec7e80eSJeff Niu } 876cec7e80eSJeff Niu append(*value); 877cec7e80eSJeff Niu return success(); 878cec7e80eSJeff Niu } 879cec7e80eSJeff Niu 880cec7e80eSJeff Niu ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) { 881cec7e80eSJeff Niu bool isNegative = p.consumeIf(Token::minus); 882cec7e80eSJeff Niu Token token = p.getToken(); 883*4548bff0SMatthias Springer std::optional<APFloat> fromIntLit; 884*4548bff0SMatthias Springer if (failed( 885*4548bff0SMatthias Springer p.parseFloatFromLiteral(fromIntLit, token, isNegative, 886*4548bff0SMatthias Springer cast<FloatType>(type).getFloatSemantics()))) 887cec7e80eSJeff Niu return failure(); 888*4548bff0SMatthias Springer p.consumeToken(); 889*4548bff0SMatthias Springer append(fromIntLit->bitcastToAPInt()); 890cec7e80eSJeff Niu return success(); 891cec7e80eSJeff Niu } 892cec7e80eSJeff Niu 893c60b897dSRiver Riddle /// Parse a dense array attribute. 
8947a7c0697SJeff Niu Attribute Parser::parseDenseArrayAttr(Type attrType) { 8952092d143SJeff Niu consumeToken(Token::kw_array); 896cec7e80eSJeff Niu if (parseToken(Token::less, "expected '<' after 'array'")) 89796da738dSJeff Niu return {}; 898c60b897dSRiver Riddle 899cec7e80eSJeff Niu SMLoc typeLoc = getToken().getLoc(); 900c48e0cf0SJeff Niu Type eltType = parseType(); 901c48e0cf0SJeff Niu if (!eltType) { 902c48e0cf0SJeff Niu emitError(typeLoc, "expected an integer or floating point type"); 9037a7c0697SJeff Niu return {}; 9047a7c0697SJeff Niu } 9057a7c0697SJeff Niu 9067a7c0697SJeff Niu // Only bool or integer and floating point elements divisible by bytes are 9077a7c0697SJeff Niu // supported. 9087a7c0697SJeff Niu if (!eltType.isIntOrIndexOrFloat()) { 9097a7c0697SJeff Niu emitError(typeLoc, "expected integer or float type, got: ") << eltType; 9107a7c0697SJeff Niu return {}; 9117a7c0697SJeff Niu } 9127a7c0697SJeff Niu if (!eltType.isInteger(1) && eltType.getIntOrFloatBitWidth() % 8 != 0) { 913cec7e80eSJeff Niu emitError(typeLoc, "element type bitwidth must be a multiple of 8"); 914cec7e80eSJeff Niu return {}; 915cec7e80eSJeff Niu } 916cec7e80eSJeff Niu 917cec7e80eSJeff Niu // Check for empty list. 
918c48e0cf0SJeff Niu if (consumeIf(Token::greater)) 919c48e0cf0SJeff Niu return DenseArrayAttr::get(eltType, 0, {}); 920c48e0cf0SJeff Niu 921c48e0cf0SJeff Niu if (parseToken(Token::colon, "expected ':' after dense array type")) 922cec7e80eSJeff Niu return {}; 923cec7e80eSJeff Niu 9247a7c0697SJeff Niu DenseArrayElementParser eltParser(eltType); 9257a7c0697SJeff Niu if (eltType.isIntOrIndex()) { 926cec7e80eSJeff Niu if (parseCommaSeparatedList( 927cec7e80eSJeff Niu [&] { return eltParser.parseIntegerElement(*this); })) 928cec7e80eSJeff Niu return {}; 929cec7e80eSJeff Niu } else { 930cec7e80eSJeff Niu if (parseCommaSeparatedList( 931cec7e80eSJeff Niu [&] { return eltParser.parseFloatElement(*this); })) 932cec7e80eSJeff Niu return {}; 933cec7e80eSJeff Niu } 9342092d143SJeff Niu if (parseToken(Token::greater, "expected '>' to close an array attribute")) 935c60b897dSRiver Riddle return {}; 936c48e0cf0SJeff Niu return eltParser.getAttr(); 937c60b897dSRiver Riddle } 938c60b897dSRiver Riddle 939c60b897dSRiver Riddle /// Parse a dense elements attribute. 940c60b897dSRiver Riddle Attribute Parser::parseDenseElementsAttr(Type attrType) { 941c60b897dSRiver Riddle auto attribLoc = getToken().getLoc(); 942c60b897dSRiver Riddle consumeToken(Token::kw_dense); 943c60b897dSRiver Riddle if (parseToken(Token::less, "expected '<' after 'dense'")) 944c60b897dSRiver Riddle return nullptr; 945c60b897dSRiver Riddle 946c60b897dSRiver Riddle // Parse the literal data if necessary. 947c60b897dSRiver Riddle TensorLiteralParser literalParser(*this); 948c60b897dSRiver Riddle if (!consumeIf(Token::greater)) { 949c60b897dSRiver Riddle if (literalParser.parse(/*allowHex=*/true) || 950c60b897dSRiver Riddle parseToken(Token::greater, "expected '>'")) 951c60b897dSRiver Riddle return nullptr; 952c60b897dSRiver Riddle } 953c60b897dSRiver Riddle 954c60b897dSRiver Riddle // If the type is specified `parseElementsLiteralType` will not parse a type. 
955c60b897dSRiver Riddle // Use the attribute location as the location for error reporting in that 956c60b897dSRiver Riddle // case. 957c60b897dSRiver Riddle auto loc = attrType ? attribLoc : getToken().getLoc(); 958c60b897dSRiver Riddle auto type = parseElementsLiteralType(attrType); 959c60b897dSRiver Riddle if (!type) 960c60b897dSRiver Riddle return nullptr; 961c60b897dSRiver Riddle return literalParser.getAttr(loc, type); 962c60b897dSRiver Riddle } 963c60b897dSRiver Riddle 964995ab929SRiver Riddle Attribute Parser::parseDenseResourceElementsAttr(Type attrType) { 965995ab929SRiver Riddle auto loc = getToken().getLoc(); 966995ab929SRiver Riddle consumeToken(Token::kw_dense_resource); 967995ab929SRiver Riddle if (parseToken(Token::less, "expected '<' after 'dense_resource'")) 968995ab929SRiver Riddle return nullptr; 969995ab929SRiver Riddle 970995ab929SRiver Riddle // Parse the resource handle. 971995ab929SRiver Riddle FailureOr<AsmDialectResourceHandle> rawHandle = 972995ab929SRiver Riddle parseResourceHandle(getContext()->getLoadedDialect<BuiltinDialect>()); 973995ab929SRiver Riddle if (failed(rawHandle) || parseToken(Token::greater, "expected '>'")) 974995ab929SRiver Riddle return nullptr; 975995ab929SRiver Riddle 976995ab929SRiver Riddle auto *handle = dyn_cast<DenseResourceElementsHandle>(&*rawHandle); 977995ab929SRiver Riddle if (!handle) 978995ab929SRiver Riddle return emitError(loc, "invalid `dense_resource` handle type"), nullptr; 979995ab929SRiver Riddle 980995ab929SRiver Riddle // Parse the type of the attribute if the user didn't provide one. 
981995ab929SRiver Riddle SMLoc typeLoc = loc; 982995ab929SRiver Riddle if (!attrType) { 983995ab929SRiver Riddle typeLoc = getToken().getLoc(); 984995ab929SRiver Riddle if (parseToken(Token::colon, "expected ':'") || !(attrType = parseType())) 985995ab929SRiver Riddle return nullptr; 986995ab929SRiver Riddle } 987995ab929SRiver Riddle 9885550c821STres Popp ShapedType shapedType = dyn_cast<ShapedType>(attrType); 989995ab929SRiver Riddle if (!shapedType) { 990995ab929SRiver Riddle emitError(typeLoc, "`dense_resource` expected a shaped type"); 991995ab929SRiver Riddle return nullptr; 992995ab929SRiver Riddle } 993995ab929SRiver Riddle 994995ab929SRiver Riddle return DenseResourceElementsAttr::get(shapedType, *handle); 995995ab929SRiver Riddle } 996995ab929SRiver Riddle 997c60b897dSRiver Riddle /// Shaped type for elements attribute. 998c60b897dSRiver Riddle /// 999c60b897dSRiver Riddle /// elements-literal-type ::= vector-type | ranked-tensor-type 1000c60b897dSRiver Riddle /// 1001c60b897dSRiver Riddle /// This method also checks the type has static shape. 1002c60b897dSRiver Riddle ShapedType Parser::parseElementsLiteralType(Type type) { 1003c60b897dSRiver Riddle // If the user didn't provide a type, parse the colon type for the literal. 
1004c60b897dSRiver Riddle if (!type) { 1005c60b897dSRiver Riddle if (parseToken(Token::colon, "expected ':'")) 1006c60b897dSRiver Riddle return nullptr; 1007c60b897dSRiver Riddle if (!(type = parseType())) 1008c60b897dSRiver Riddle return nullptr; 1009c60b897dSRiver Riddle } 1010c60b897dSRiver Riddle 10115550c821STres Popp auto sType = dyn_cast<ShapedType>(type); 1012a5c46bf9SJeff Niu if (!sType) { 1013a5c46bf9SJeff Niu emitError("elements literal must be a shaped type"); 1014c60b897dSRiver Riddle return nullptr; 1015c60b897dSRiver Riddle } 1016c60b897dSRiver Riddle 1017c60b897dSRiver Riddle if (!sType.hasStaticShape()) 1018c60b897dSRiver Riddle return (emitError("elements literal type must have static shape"), nullptr); 1019c60b897dSRiver Riddle 1020c60b897dSRiver Riddle return sType; 1021c60b897dSRiver Riddle } 1022c60b897dSRiver Riddle 1023c60b897dSRiver Riddle /// Parse a sparse elements attribute. 1024c60b897dSRiver Riddle Attribute Parser::parseSparseElementsAttr(Type attrType) { 1025c60b897dSRiver Riddle SMLoc loc = getToken().getLoc(); 1026c60b897dSRiver Riddle consumeToken(Token::kw_sparse); 1027c60b897dSRiver Riddle if (parseToken(Token::less, "Expected '<' after 'sparse'")) 1028c60b897dSRiver Riddle return nullptr; 1029c60b897dSRiver Riddle 1030c60b897dSRiver Riddle // Check for the case where all elements are sparse. The indices are 1031c60b897dSRiver Riddle // represented by a 2-dimensional shape where the second dimension is the rank 1032c60b897dSRiver Riddle // of the type. 1033c60b897dSRiver Riddle Type indiceEltType = builder.getIntegerType(64); 1034c60b897dSRiver Riddle if (consumeIf(Token::greater)) { 1035c60b897dSRiver Riddle ShapedType type = parseElementsLiteralType(attrType); 1036c60b897dSRiver Riddle if (!type) 1037c60b897dSRiver Riddle return nullptr; 1038c60b897dSRiver Riddle 1039c60b897dSRiver Riddle // Construct the sparse elements attr using zero element indice/value 1040c60b897dSRiver Riddle // attributes. 
1041c60b897dSRiver Riddle ShapedType indicesType = 1042c60b897dSRiver Riddle RankedTensorType::get({0, type.getRank()}, indiceEltType); 1043c60b897dSRiver Riddle ShapedType valuesType = RankedTensorType::get({0}, type.getElementType()); 1044c60b897dSRiver Riddle return getChecked<SparseElementsAttr>( 1045c60b897dSRiver Riddle loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()), 1046c60b897dSRiver Riddle DenseElementsAttr::get(valuesType, ArrayRef<Attribute>())); 1047c60b897dSRiver Riddle } 1048c60b897dSRiver Riddle 1049c60b897dSRiver Riddle /// Parse the indices. We don't allow hex values here as we may need to use 1050c60b897dSRiver Riddle /// the inferred shape. 1051c60b897dSRiver Riddle auto indicesLoc = getToken().getLoc(); 1052c60b897dSRiver Riddle TensorLiteralParser indiceParser(*this); 1053c60b897dSRiver Riddle if (indiceParser.parse(/*allowHex=*/false)) 1054c60b897dSRiver Riddle return nullptr; 1055c60b897dSRiver Riddle 1056c60b897dSRiver Riddle if (parseToken(Token::comma, "expected ','")) 1057c60b897dSRiver Riddle return nullptr; 1058c60b897dSRiver Riddle 1059c60b897dSRiver Riddle /// Parse the values. 1060c60b897dSRiver Riddle auto valuesLoc = getToken().getLoc(); 1061c60b897dSRiver Riddle TensorLiteralParser valuesParser(*this); 1062c60b897dSRiver Riddle if (valuesParser.parse(/*allowHex=*/true)) 1063c60b897dSRiver Riddle return nullptr; 1064c60b897dSRiver Riddle 1065c60b897dSRiver Riddle if (parseToken(Token::greater, "expected '>'")) 1066c60b897dSRiver Riddle return nullptr; 1067c60b897dSRiver Riddle 1068c60b897dSRiver Riddle auto type = parseElementsLiteralType(attrType); 1069c60b897dSRiver Riddle if (!type) 1070c60b897dSRiver Riddle return nullptr; 1071c60b897dSRiver Riddle 1072c60b897dSRiver Riddle // If the indices are a splat, i.e. the literal parser parsed an element and 1073c60b897dSRiver Riddle // not a list, we set the shape explicitly. 
The indices are represented by a 1074c60b897dSRiver Riddle // 2-dimensional shape where the second dimension is the rank of the type. 1075c60b897dSRiver Riddle // Given that the parsed indices is a splat, we know that we only have one 1076c60b897dSRiver Riddle // indice and thus one for the first dimension. 1077c60b897dSRiver Riddle ShapedType indicesType; 1078c60b897dSRiver Riddle if (indiceParser.getShape().empty()) { 1079c60b897dSRiver Riddle indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType); 1080c60b897dSRiver Riddle } else { 1081c60b897dSRiver Riddle // Otherwise, set the shape to the one parsed by the literal parser. 1082c60b897dSRiver Riddle indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); 1083c60b897dSRiver Riddle } 1084c60b897dSRiver Riddle auto indices = indiceParser.getAttr(indicesLoc, indicesType); 1085c60b897dSRiver Riddle 1086c60b897dSRiver Riddle // If the values are a splat, set the shape explicitly based on the number of 1087c60b897dSRiver Riddle // indices. The number of indices is encoded in the first dimension of the 1088c60b897dSRiver Riddle // indice shape type. 1089c60b897dSRiver Riddle auto valuesEltType = type.getElementType(); 1090c60b897dSRiver Riddle ShapedType valuesType = 1091c60b897dSRiver Riddle valuesParser.getShape().empty() 1092c60b897dSRiver Riddle ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) 1093c60b897dSRiver Riddle : RankedTensorType::get(valuesParser.getShape(), valuesEltType); 1094c60b897dSRiver Riddle auto values = valuesParser.getAttr(valuesLoc, valuesType); 1095c60b897dSRiver Riddle 1096c60b897dSRiver Riddle // Build the sparse elements attribute by the indices and values. 
1097c60b897dSRiver Riddle return getChecked<SparseElementsAttr>(loc, type, indices, values); 1098c60b897dSRiver Riddle } 1099519847feSAlex Zinenko 1100519847feSAlex Zinenko Attribute Parser::parseStridedLayoutAttr() { 1101519847feSAlex Zinenko // Callback for error emissing at the keyword token location. 1102519847feSAlex Zinenko llvm::SMLoc loc = getToken().getLoc(); 1103519847feSAlex Zinenko auto errorEmitter = [&] { return emitError(loc); }; 1104519847feSAlex Zinenko 1105519847feSAlex Zinenko consumeToken(Token::kw_strided); 1106519847feSAlex Zinenko if (failed(parseToken(Token::less, "expected '<' after 'strided'")) || 1107519847feSAlex Zinenko failed(parseToken(Token::l_square, "expected '['"))) 1108519847feSAlex Zinenko return nullptr; 1109519847feSAlex Zinenko 1110519847feSAlex Zinenko // Parses either an integer token or a question mark token. Reports an error 111170c73d1bSKazu Hirata // and returns std::nullopt if the current token is neither. The integer token 111270c73d1bSKazu Hirata // must fit into int64_t limits. 
11130a81ace0SKazu Hirata auto parseStrideOrOffset = [&]() -> std::optional<int64_t> { 1114519847feSAlex Zinenko if (consumeIf(Token::question)) 1115399638f9SAliia Khasanova return ShapedType::kDynamic; 1116519847feSAlex Zinenko 1117519847feSAlex Zinenko SMLoc loc = getToken().getLoc(); 1118519847feSAlex Zinenko auto emitWrongTokenError = [&] { 111954d81e49SIvan Butygin emitError(loc, "expected a 64-bit signed integer or '?'"); 11201a36588eSKazu Hirata return std::nullopt; 1121519847feSAlex Zinenko }; 1122519847feSAlex Zinenko 112354d81e49SIvan Butygin bool negative = consumeIf(Token::minus); 112454d81e49SIvan Butygin 1125519847feSAlex Zinenko if (getToken().is(Token::integer)) { 11260a81ace0SKazu Hirata std::optional<uint64_t> value = getToken().getUInt64IntegerValue(); 11270fd95808SAlex Zinenko if (!value || 11280fd95808SAlex Zinenko *value > static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) 1129519847feSAlex Zinenko return emitWrongTokenError(); 1130519847feSAlex Zinenko consumeToken(); 113154d81e49SIvan Butygin auto result = static_cast<int64_t>(*value); 113254d81e49SIvan Butygin if (negative) 113354d81e49SIvan Butygin result = -result; 113454d81e49SIvan Butygin 113554d81e49SIvan Butygin return result; 1136519847feSAlex Zinenko } 1137519847feSAlex Zinenko 1138519847feSAlex Zinenko return emitWrongTokenError(); 1139519847feSAlex Zinenko }; 1140519847feSAlex Zinenko 1141519847feSAlex Zinenko // Parse strides. 
1142519847feSAlex Zinenko SmallVector<int64_t> strides; 1143519847feSAlex Zinenko if (!getToken().is(Token::r_square)) { 1144519847feSAlex Zinenko do { 11450a81ace0SKazu Hirata std::optional<int64_t> stride = parseStrideOrOffset(); 1146519847feSAlex Zinenko if (!stride) 1147519847feSAlex Zinenko return nullptr; 1148519847feSAlex Zinenko strides.push_back(*stride); 1149519847feSAlex Zinenko } while (consumeIf(Token::comma)); 1150519847feSAlex Zinenko } 1151519847feSAlex Zinenko 1152519847feSAlex Zinenko if (failed(parseToken(Token::r_square, "expected ']'"))) 1153519847feSAlex Zinenko return nullptr; 1154519847feSAlex Zinenko 1155519847feSAlex Zinenko // Fast path in absence of offset. 1156519847feSAlex Zinenko if (consumeIf(Token::greater)) { 1157519847feSAlex Zinenko if (failed(StridedLayoutAttr::verify(errorEmitter, 1158519847feSAlex Zinenko /*offset=*/0, strides))) 1159519847feSAlex Zinenko return nullptr; 1160519847feSAlex Zinenko return StridedLayoutAttr::get(getContext(), /*offset=*/0, strides); 1161519847feSAlex Zinenko } 1162519847feSAlex Zinenko 1163519847feSAlex Zinenko if (failed(parseToken(Token::comma, "expected ','")) || 1164519847feSAlex Zinenko failed(parseToken(Token::kw_offset, "expected 'offset' after comma")) || 1165519847feSAlex Zinenko failed(parseToken(Token::colon, "expected ':' after 'offset'"))) 1166519847feSAlex Zinenko return nullptr; 1167519847feSAlex Zinenko 11680a81ace0SKazu Hirata std::optional<int64_t> offset = parseStrideOrOffset(); 1169519847feSAlex Zinenko if (!offset || failed(parseToken(Token::greater, "expected '>'"))) 1170519847feSAlex Zinenko return nullptr; 1171519847feSAlex Zinenko 1172519847feSAlex Zinenko if (failed(StridedLayoutAttr::verify(errorEmitter, *offset, strides))) 1173519847feSAlex Zinenko return nullptr; 1174519847feSAlex Zinenko return StridedLayoutAttr::get(getContext(), *offset, strides); 1175519847feSAlex Zinenko // return getChecked<StridedLayoutAttr>(loc,getContext(), *offset, strides); 
1176519847feSAlex Zinenko } 1177728a8d5aSTobias Gysi 1178728a8d5aSTobias Gysi /// Parse a distinct attribute. 1179728a8d5aSTobias Gysi /// 1180728a8d5aSTobias Gysi /// distinct-attribute ::= `distinct` 1181728a8d5aSTobias Gysi /// `[` integer-literal `]<` attribute-value `>` 1182728a8d5aSTobias Gysi /// 1183728a8d5aSTobias Gysi Attribute Parser::parseDistinctAttr(Type type) { 118472307960SJacques Pienaar SMLoc loc = getToken().getLoc(); 1185728a8d5aSTobias Gysi consumeToken(Token::kw_distinct); 1186728a8d5aSTobias Gysi if (parseToken(Token::l_square, "expected '[' after 'distinct'")) 1187728a8d5aSTobias Gysi return {}; 1188728a8d5aSTobias Gysi 1189728a8d5aSTobias Gysi // Parse the distinct integer identifier. 1190728a8d5aSTobias Gysi Token token = getToken(); 1191728a8d5aSTobias Gysi if (parseToken(Token::integer, "expected distinct ID")) 1192728a8d5aSTobias Gysi return {}; 1193728a8d5aSTobias Gysi std::optional<uint64_t> value = token.getUInt64IntegerValue(); 1194728a8d5aSTobias Gysi if (!value) { 1195728a8d5aSTobias Gysi emitError("expected an unsigned 64-bit integer"); 1196728a8d5aSTobias Gysi return {}; 1197728a8d5aSTobias Gysi } 1198728a8d5aSTobias Gysi 1199728a8d5aSTobias Gysi // Parse the referenced attribute. 
1200728a8d5aSTobias Gysi if (parseToken(Token::r_square, "expected ']' to close distinct ID") || 1201728a8d5aSTobias Gysi parseToken(Token::less, "expected '<' after distinct ID")) 1202728a8d5aSTobias Gysi return {}; 1203629460a9SMarkus Böck 1204629460a9SMarkus Böck Attribute referencedAttr; 1205629460a9SMarkus Böck if (getToken().is(Token::greater)) { 1206629460a9SMarkus Böck consumeToken(); 1207629460a9SMarkus Böck referencedAttr = builder.getUnitAttr(); 1208629460a9SMarkus Böck } else { 1209629460a9SMarkus Böck referencedAttr = parseAttribute(type); 1210728a8d5aSTobias Gysi if (!referencedAttr) { 1211728a8d5aSTobias Gysi emitError("expected attribute"); 1212728a8d5aSTobias Gysi return {}; 1213728a8d5aSTobias Gysi } 1214728a8d5aSTobias Gysi 1215629460a9SMarkus Böck if (parseToken(Token::greater, "expected '>' to close distinct attribute")) 1216629460a9SMarkus Böck return {}; 1217629460a9SMarkus Böck } 1218629460a9SMarkus Böck 1219728a8d5aSTobias Gysi // Add the distinct attribute to the parser state, if it has not been parsed 1220728a8d5aSTobias Gysi // before. Otherwise, check if the parsed reference attribute matches the one 1221728a8d5aSTobias Gysi // found in the parser state. 
1222728a8d5aSTobias Gysi DenseMap<uint64_t, DistinctAttr> &distinctAttrs = 1223728a8d5aSTobias Gysi state.symbols.distinctAttributes; 1224728a8d5aSTobias Gysi auto it = distinctAttrs.find(*value); 1225728a8d5aSTobias Gysi if (it == distinctAttrs.end()) { 1226728a8d5aSTobias Gysi DistinctAttr distinctAttr = DistinctAttr::create(referencedAttr); 1227728a8d5aSTobias Gysi it = distinctAttrs.try_emplace(*value, distinctAttr).first; 1228728a8d5aSTobias Gysi } else if (it->getSecond().getReferencedAttr() != referencedAttr) { 122972307960SJacques Pienaar emitError(loc, "referenced attribute does not match previous definition: ") 1230728a8d5aSTobias Gysi << it->getSecond().getReferencedAttr(); 1231728a8d5aSTobias Gysi return {}; 1232728a8d5aSTobias Gysi } 1233728a8d5aSTobias Gysi 1234728a8d5aSTobias Gysi return it->getSecond(); 1235728a8d5aSTobias Gysi } 1236