xref: /llvm-project/mlir/lib/AsmParser/AttributeParser.cpp (revision 4548bff0e8139d4f375f1078dd50a74116eae0a2)
1c60b897dSRiver Riddle //===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===//
2c60b897dSRiver Riddle //
3c60b897dSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c60b897dSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5c60b897dSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c60b897dSRiver Riddle //
7c60b897dSRiver Riddle //===----------------------------------------------------------------------===//
8c60b897dSRiver Riddle //
// This file implements the parser for MLIR attributes.
10c60b897dSRiver Riddle //
11c60b897dSRiver Riddle //===----------------------------------------------------------------------===//
12c60b897dSRiver Riddle 
13c60b897dSRiver Riddle #include "Parser.h"
14c60b897dSRiver Riddle 
15c60b897dSRiver Riddle #include "AsmParserImpl.h"
16c60b897dSRiver Riddle #include "mlir/AsmParser/AsmParserState.h"
17c60b897dSRiver Riddle #include "mlir/IR/AffineMap.h"
18519847feSAlex Zinenko #include "mlir/IR/BuiltinAttributes.h"
19995ab929SRiver Riddle #include "mlir/IR/BuiltinDialect.h"
20c60b897dSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
21c60b897dSRiver Riddle #include "mlir/IR/DialectImplementation.h"
22995ab929SRiver Riddle #include "mlir/IR/DialectResourceBlobManager.h"
23c60b897dSRiver Riddle #include "mlir/IR/IntegerSet.h"
24c60b897dSRiver Riddle #include "llvm/ADT/StringExtras.h"
25c60b897dSRiver Riddle #include "llvm/Support/Endian.h"
26d8be2081SKazu Hirata #include <optional>
27c60b897dSRiver Riddle 
28c60b897dSRiver Riddle using namespace mlir;
29c60b897dSRiver Riddle using namespace mlir::detail;
30c60b897dSRiver Riddle 
31c60b897dSRiver Riddle /// Parse an arbitrary attribute.
32c60b897dSRiver Riddle ///
33c60b897dSRiver Riddle ///  attribute-value ::= `unit`
34c60b897dSRiver Riddle ///                    | bool-literal
35c60b897dSRiver Riddle ///                    | integer-literal (`:` (index-type | integer-type))?
36c60b897dSRiver Riddle ///                    | float-literal (`:` float-type)?
37c60b897dSRiver Riddle ///                    | string-literal (`:` type)?
38c60b897dSRiver Riddle ///                    | type
39c60b897dSRiver Riddle ///                    | `[` `:` (integer-type | float-type) tensor-literal `]`
40c60b897dSRiver Riddle ///                    | `[` (attribute-value (`,` attribute-value)*)? `]`
41c60b897dSRiver Riddle ///                    | `{` (attribute-entry (`,` attribute-entry)*)? `}`
42c60b897dSRiver Riddle ///                    | symbol-ref-id (`::` symbol-ref-id)*
43c60b897dSRiver Riddle ///                    | `dense` `<` tensor-literal `>` `:`
44c60b897dSRiver Riddle ///                      (tensor-type | vector-type)
45c60b897dSRiver Riddle ///                    | `sparse` `<` attribute-value `,` attribute-value `>`
46c60b897dSRiver Riddle ///                      `:` (tensor-type | vector-type)
47519847feSAlex Zinenko ///                    | `strided` `<` `[` comma-separated-int-or-question `]`
48519847feSAlex Zinenko ///                      (`,` `offset` `:` integer-literal)? `>`
49728a8d5aSTobias Gysi ///                    | distinct-attribute
50c60b897dSRiver Riddle ///                    | extended-attribute
51c60b897dSRiver Riddle ///
52c60b897dSRiver Riddle Attribute Parser::parseAttribute(Type type) {
53c60b897dSRiver Riddle   switch (getToken().getKind()) {
54c60b897dSRiver Riddle   // Parse an AffineMap or IntegerSet attribute.
55c60b897dSRiver Riddle   case Token::kw_affine_map: {
56c60b897dSRiver Riddle     consumeToken(Token::kw_affine_map);
57c60b897dSRiver Riddle 
58c60b897dSRiver Riddle     AffineMap map;
59c60b897dSRiver Riddle     if (parseToken(Token::less, "expected '<' in affine map") ||
60c60b897dSRiver Riddle         parseAffineMapReference(map) ||
61c60b897dSRiver Riddle         parseToken(Token::greater, "expected '>' in affine map"))
62c60b897dSRiver Riddle       return Attribute();
63c60b897dSRiver Riddle     return AffineMapAttr::get(map);
64c60b897dSRiver Riddle   }
65c60b897dSRiver Riddle   case Token::kw_affine_set: {
66c60b897dSRiver Riddle     consumeToken(Token::kw_affine_set);
67c60b897dSRiver Riddle 
68c60b897dSRiver Riddle     IntegerSet set;
69c60b897dSRiver Riddle     if (parseToken(Token::less, "expected '<' in integer set") ||
70c60b897dSRiver Riddle         parseIntegerSetReference(set) ||
71c60b897dSRiver Riddle         parseToken(Token::greater, "expected '>' in integer set"))
72c60b897dSRiver Riddle       return Attribute();
73c60b897dSRiver Riddle     return IntegerSetAttr::get(set);
74c60b897dSRiver Riddle   }
75c60b897dSRiver Riddle 
76c60b897dSRiver Riddle   // Parse an array attribute.
77c60b897dSRiver Riddle   case Token::l_square: {
78c60b897dSRiver Riddle     consumeToken(Token::l_square);
79c60b897dSRiver Riddle     SmallVector<Attribute, 4> elements;
80c60b897dSRiver Riddle     auto parseElt = [&]() -> ParseResult {
81c60b897dSRiver Riddle       elements.push_back(parseAttribute());
82c60b897dSRiver Riddle       return elements.back() ? success() : failure();
83c60b897dSRiver Riddle     };
84c60b897dSRiver Riddle 
85c60b897dSRiver Riddle     if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
86c60b897dSRiver Riddle       return nullptr;
87c60b897dSRiver Riddle     return builder.getArrayAttr(elements);
88c60b897dSRiver Riddle   }
89c60b897dSRiver Riddle 
90c60b897dSRiver Riddle   // Parse a boolean attribute.
91c60b897dSRiver Riddle   case Token::kw_false:
92c60b897dSRiver Riddle     consumeToken(Token::kw_false);
93c60b897dSRiver Riddle     return builder.getBoolAttr(false);
94c60b897dSRiver Riddle   case Token::kw_true:
95c60b897dSRiver Riddle     consumeToken(Token::kw_true);
96c60b897dSRiver Riddle     return builder.getBoolAttr(true);
97c60b897dSRiver Riddle 
98c60b897dSRiver Riddle   // Parse a dense elements attribute.
99c60b897dSRiver Riddle   case Token::kw_dense:
100c60b897dSRiver Riddle     return parseDenseElementsAttr(type);
101c60b897dSRiver Riddle 
102995ab929SRiver Riddle   // Parse a dense resource elements attribute.
103995ab929SRiver Riddle   case Token::kw_dense_resource:
104995ab929SRiver Riddle     return parseDenseResourceElementsAttr(type);
105995ab929SRiver Riddle 
1062092d143SJeff Niu   // Parse a dense array attribute.
1072092d143SJeff Niu   case Token::kw_array:
1082092d143SJeff Niu     return parseDenseArrayAttr(type);
1092092d143SJeff Niu 
110c60b897dSRiver Riddle   // Parse a dictionary attribute.
111c60b897dSRiver Riddle   case Token::l_brace: {
112c60b897dSRiver Riddle     NamedAttrList elements;
113c60b897dSRiver Riddle     if (parseAttributeDict(elements))
114c60b897dSRiver Riddle       return nullptr;
115c60b897dSRiver Riddle     return elements.getDictionary(getContext());
116c60b897dSRiver Riddle   }
117c60b897dSRiver Riddle 
118c60b897dSRiver Riddle   // Parse an extended attribute, i.e. alias or dialect attribute.
119c60b897dSRiver Riddle   case Token::hash_identifier:
120c60b897dSRiver Riddle     return parseExtendedAttr(type);
121c60b897dSRiver Riddle 
122c60b897dSRiver Riddle   // Parse floating point and integer attributes.
123c60b897dSRiver Riddle   case Token::floatliteral:
124c60b897dSRiver Riddle     return parseFloatAttr(type, /*isNegative=*/false);
125c60b897dSRiver Riddle   case Token::integer:
126c60b897dSRiver Riddle     return parseDecOrHexAttr(type, /*isNegative=*/false);
127c60b897dSRiver Riddle   case Token::minus: {
128c60b897dSRiver Riddle     consumeToken(Token::minus);
129c60b897dSRiver Riddle     if (getToken().is(Token::integer))
130c60b897dSRiver Riddle       return parseDecOrHexAttr(type, /*isNegative=*/true);
131c60b897dSRiver Riddle     if (getToken().is(Token::floatliteral))
132c60b897dSRiver Riddle       return parseFloatAttr(type, /*isNegative=*/true);
133c60b897dSRiver Riddle 
134c60b897dSRiver Riddle     return (emitWrongTokenError(
135c60b897dSRiver Riddle                 "expected constant integer or floating point value"),
136c60b897dSRiver Riddle             nullptr);
137c60b897dSRiver Riddle   }
138c60b897dSRiver Riddle 
139c60b897dSRiver Riddle   // Parse a location attribute.
140c60b897dSRiver Riddle   case Token::kw_loc: {
141c60b897dSRiver Riddle     consumeToken(Token::kw_loc);
142c60b897dSRiver Riddle 
143c60b897dSRiver Riddle     LocationAttr locAttr;
144c60b897dSRiver Riddle     if (parseToken(Token::l_paren, "expected '(' in inline location") ||
145c60b897dSRiver Riddle         parseLocationInstance(locAttr) ||
146c60b897dSRiver Riddle         parseToken(Token::r_paren, "expected ')' in inline location"))
147c60b897dSRiver Riddle       return Attribute();
148c60b897dSRiver Riddle     return locAttr;
149c60b897dSRiver Riddle   }
150c60b897dSRiver Riddle 
151c60b897dSRiver Riddle   // Parse a sparse elements attribute.
152c60b897dSRiver Riddle   case Token::kw_sparse:
153c60b897dSRiver Riddle     return parseSparseElementsAttr(type);
154c60b897dSRiver Riddle 
155519847feSAlex Zinenko   // Parse a strided layout attribute.
156519847feSAlex Zinenko   case Token::kw_strided:
157519847feSAlex Zinenko     return parseStridedLayoutAttr();
158519847feSAlex Zinenko 
159728a8d5aSTobias Gysi   // Parse a distinct attribute.
160728a8d5aSTobias Gysi   case Token::kw_distinct:
161728a8d5aSTobias Gysi     return parseDistinctAttr(type);
162728a8d5aSTobias Gysi 
163c60b897dSRiver Riddle   // Parse a string attribute.
164c60b897dSRiver Riddle   case Token::string: {
165c60b897dSRiver Riddle     auto val = getToken().getStringValue();
166c60b897dSRiver Riddle     consumeToken(Token::string);
167c60b897dSRiver Riddle     // Parse the optional trailing colon type if one wasn't explicitly provided.
168c60b897dSRiver Riddle     if (!type && consumeIf(Token::colon) && !(type = parseType()))
169c60b897dSRiver Riddle       return Attribute();
170c60b897dSRiver Riddle 
171c60b897dSRiver Riddle     return type ? StringAttr::get(val, type)
172c60b897dSRiver Riddle                 : StringAttr::get(getContext(), val);
173c60b897dSRiver Riddle   }
174c60b897dSRiver Riddle 
175c60b897dSRiver Riddle   // Parse a symbol reference attribute.
176c60b897dSRiver Riddle   case Token::at_identifier: {
177c60b897dSRiver Riddle     // When populating the parser state, this is a list of locations for all of
178c60b897dSRiver Riddle     // the nested references.
179c60b897dSRiver Riddle     SmallVector<SMRange> referenceLocations;
180c60b897dSRiver Riddle     if (state.asmState)
181c60b897dSRiver Riddle       referenceLocations.push_back(getToken().getLocRange());
182c60b897dSRiver Riddle 
183c60b897dSRiver Riddle     // Parse the top-level reference.
184c60b897dSRiver Riddle     std::string nameStr = getToken().getSymbolReference();
185c60b897dSRiver Riddle     consumeToken(Token::at_identifier);
186c60b897dSRiver Riddle 
187c60b897dSRiver Riddle     // Parse any nested references.
188c60b897dSRiver Riddle     std::vector<FlatSymbolRefAttr> nestedRefs;
189c60b897dSRiver Riddle     while (getToken().is(Token::colon)) {
190c60b897dSRiver Riddle       // Check for the '::' prefix.
191c60b897dSRiver Riddle       const char *curPointer = getToken().getLoc().getPointer();
192c60b897dSRiver Riddle       consumeToken(Token::colon);
193c60b897dSRiver Riddle       if (!consumeIf(Token::colon)) {
194c60b897dSRiver Riddle         if (getToken().isNot(Token::eof, Token::error)) {
195c60b897dSRiver Riddle           state.lex.resetPointer(curPointer);
196c60b897dSRiver Riddle           consumeToken();
197c60b897dSRiver Riddle         }
198c60b897dSRiver Riddle         break;
199c60b897dSRiver Riddle       }
200c60b897dSRiver Riddle       // Parse the reference itself.
201c60b897dSRiver Riddle       auto curLoc = getToken().getLoc();
202c60b897dSRiver Riddle       if (getToken().isNot(Token::at_identifier)) {
203c60b897dSRiver Riddle         emitError(curLoc, "expected nested symbol reference identifier");
204c60b897dSRiver Riddle         return Attribute();
205c60b897dSRiver Riddle       }
206c60b897dSRiver Riddle 
207c60b897dSRiver Riddle       // If we are populating the assembly state, add the location for this
208c60b897dSRiver Riddle       // reference.
209c60b897dSRiver Riddle       if (state.asmState)
210c60b897dSRiver Riddle         referenceLocations.push_back(getToken().getLocRange());
211c60b897dSRiver Riddle 
212c60b897dSRiver Riddle       std::string nameStr = getToken().getSymbolReference();
213c60b897dSRiver Riddle       consumeToken(Token::at_identifier);
214c60b897dSRiver Riddle       nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
215c60b897dSRiver Riddle     }
216c60b897dSRiver Riddle     SymbolRefAttr symbolRefAttr =
217c60b897dSRiver Riddle         SymbolRefAttr::get(getContext(), nameStr, nestedRefs);
218c60b897dSRiver Riddle 
219c60b897dSRiver Riddle     // If we are populating the assembly state, record this symbol reference.
220c60b897dSRiver Riddle     if (state.asmState)
221c60b897dSRiver Riddle       state.asmState->addUses(symbolRefAttr, referenceLocations);
222c60b897dSRiver Riddle     return symbolRefAttr;
223c60b897dSRiver Riddle   }
224c60b897dSRiver Riddle 
225c60b897dSRiver Riddle   // Parse a 'unit' attribute.
226c60b897dSRiver Riddle   case Token::kw_unit:
227c60b897dSRiver Riddle     consumeToken(Token::kw_unit);
228c60b897dSRiver Riddle     return builder.getUnitAttr();
229c60b897dSRiver Riddle 
230c60b897dSRiver Riddle     // Handle completion of an attribute.
231c60b897dSRiver Riddle   case Token::code_complete:
232c60b897dSRiver Riddle     if (getToken().isCodeCompletionFor(Token::hash_identifier))
233c60b897dSRiver Riddle       return parseExtendedAttr(type);
234c60b897dSRiver Riddle     return codeCompleteAttribute();
235c60b897dSRiver Riddle 
236c60b897dSRiver Riddle   default:
237c60b897dSRiver Riddle     // Parse a type attribute. We parse `Optional` here to allow for providing a
238c60b897dSRiver Riddle     // better error message.
239c60b897dSRiver Riddle     Type type;
240c60b897dSRiver Riddle     OptionalParseResult result = parseOptionalType(type);
2419750648cSKazu Hirata     if (!result.has_value())
242c60b897dSRiver Riddle       return emitWrongTokenError("expected attribute value"), Attribute();
243c60b897dSRiver Riddle     return failed(*result) ? Attribute() : TypeAttr::get(type);
244c60b897dSRiver Riddle   }
245c60b897dSRiver Riddle }
246c60b897dSRiver Riddle 
247c60b897dSRiver Riddle /// Parse an optional attribute with the provided type.
248c60b897dSRiver Riddle OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
249c60b897dSRiver Riddle                                                    Type type) {
250c60b897dSRiver Riddle   switch (getToken().getKind()) {
251c60b897dSRiver Riddle   case Token::at_identifier:
252c60b897dSRiver Riddle   case Token::floatliteral:
253c60b897dSRiver Riddle   case Token::integer:
254c60b897dSRiver Riddle   case Token::hash_identifier:
255c60b897dSRiver Riddle   case Token::kw_affine_map:
256c60b897dSRiver Riddle   case Token::kw_affine_set:
257c60b897dSRiver Riddle   case Token::kw_dense:
258995ab929SRiver Riddle   case Token::kw_dense_resource:
259c60b897dSRiver Riddle   case Token::kw_false:
260c60b897dSRiver Riddle   case Token::kw_loc:
261c60b897dSRiver Riddle   case Token::kw_sparse:
262c60b897dSRiver Riddle   case Token::kw_true:
263c60b897dSRiver Riddle   case Token::kw_unit:
264c60b897dSRiver Riddle   case Token::l_brace:
265c60b897dSRiver Riddle   case Token::l_square:
266c60b897dSRiver Riddle   case Token::minus:
267c60b897dSRiver Riddle   case Token::string:
268c60b897dSRiver Riddle     attribute = parseAttribute(type);
269c60b897dSRiver Riddle     return success(attribute != nullptr);
270c60b897dSRiver Riddle 
271c60b897dSRiver Riddle   default:
272c60b897dSRiver Riddle     // Parse an optional type attribute.
273c60b897dSRiver Riddle     Type type;
274c60b897dSRiver Riddle     OptionalParseResult result = parseOptionalType(type);
2759750648cSKazu Hirata     if (result.has_value() && succeeded(*result))
276c60b897dSRiver Riddle       attribute = TypeAttr::get(type);
277c60b897dSRiver Riddle     return result;
278c60b897dSRiver Riddle   }
279c60b897dSRiver Riddle }
280c60b897dSRiver Riddle OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
281c60b897dSRiver Riddle                                                    Type type) {
282c60b897dSRiver Riddle   return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
283c60b897dSRiver Riddle }
284c60b897dSRiver Riddle OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute,
285c60b897dSRiver Riddle                                                    Type type) {
286c60b897dSRiver Riddle   return parseOptionalAttributeWithToken(Token::string, attribute, type);
287c60b897dSRiver Riddle }
2888bb8421bSRiver Riddle OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result,
2898bb8421bSRiver Riddle                                                    Type type) {
2908bb8421bSRiver Riddle   return parseOptionalAttributeWithToken(Token::at_identifier, result, type);
2918bb8421bSRiver Riddle }
292c60b897dSRiver Riddle 
293c60b897dSRiver Riddle /// Attribute dictionary.
294c60b897dSRiver Riddle ///
295c60b897dSRiver Riddle ///   attribute-dict ::= `{` `}`
296c60b897dSRiver Riddle ///                    | `{` attribute-entry (`,` attribute-entry)* `}`
297c60b897dSRiver Riddle ///   attribute-entry ::= (bare-id | string-literal) `=` attribute-value
298c60b897dSRiver Riddle ///
299c60b897dSRiver Riddle ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
300c60b897dSRiver Riddle   llvm::SmallDenseSet<StringAttr> seenKeys;
301c60b897dSRiver Riddle   auto parseElt = [&]() -> ParseResult {
302c60b897dSRiver Riddle     // The name of an attribute can either be a bare identifier, or a string.
303d8be2081SKazu Hirata     std::optional<StringAttr> nameId;
304c60b897dSRiver Riddle     if (getToken().is(Token::string))
305c60b897dSRiver Riddle       nameId = builder.getStringAttr(getToken().getStringValue());
306c60b897dSRiver Riddle     else if (getToken().isAny(Token::bare_identifier, Token::inttype) ||
307c60b897dSRiver Riddle              getToken().isKeyword())
308c60b897dSRiver Riddle       nameId = builder.getStringAttr(getTokenSpelling());
309c60b897dSRiver Riddle     else
310c60b897dSRiver Riddle       return emitWrongTokenError("expected attribute name");
311c60b897dSRiver Riddle 
312398e48a7SAdrian Kuegel     if (nameId->empty())
313c60b897dSRiver Riddle       return emitError("expected valid attribute name");
314c60b897dSRiver Riddle 
315c60b897dSRiver Riddle     if (!seenKeys.insert(*nameId).second)
316c60b897dSRiver Riddle       return emitError("duplicate key '")
317c60b897dSRiver Riddle              << nameId->getValue() << "' in dictionary attribute";
318c60b897dSRiver Riddle     consumeToken();
319c60b897dSRiver Riddle 
320c60b897dSRiver Riddle     // Lazy load a dialect in the context if there is a possible namespace.
321c60b897dSRiver Riddle     auto splitName = nameId->strref().split('.');
322c60b897dSRiver Riddle     if (!splitName.second.empty())
323c60b897dSRiver Riddle       getContext()->getOrLoadDialect(splitName.first);
324c60b897dSRiver Riddle 
325c60b897dSRiver Riddle     // Try to parse the '=' for the attribute value.
326c60b897dSRiver Riddle     if (!consumeIf(Token::equal)) {
327c60b897dSRiver Riddle       // If there is no '=', we treat this as a unit attribute.
328c60b897dSRiver Riddle       attributes.push_back({*nameId, builder.getUnitAttr()});
329c60b897dSRiver Riddle       return success();
330c60b897dSRiver Riddle     }
331c60b897dSRiver Riddle 
332c60b897dSRiver Riddle     auto attr = parseAttribute();
333c60b897dSRiver Riddle     if (!attr)
334c60b897dSRiver Riddle       return failure();
335c60b897dSRiver Riddle     attributes.push_back({*nameId, attr});
336c60b897dSRiver Riddle     return success();
337c60b897dSRiver Riddle   };
338c60b897dSRiver Riddle 
339c60b897dSRiver Riddle   return parseCommaSeparatedList(Delimiter::Braces, parseElt,
340c60b897dSRiver Riddle                                  " in attribute dictionary");
341c60b897dSRiver Riddle }
342c60b897dSRiver Riddle 
343c60b897dSRiver Riddle /// Parse a float attribute.
344c60b897dSRiver Riddle Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
345c60b897dSRiver Riddle   auto val = getToken().getFloatingPointValue();
346c60b897dSRiver Riddle   if (!val)
347c60b897dSRiver Riddle     return (emitError("floating point value too large for attribute"), nullptr);
348c60b897dSRiver Riddle   consumeToken(Token::floatliteral);
349c60b897dSRiver Riddle   if (!type) {
350c60b897dSRiver Riddle     // Default to F64 when no type is specified.
351c60b897dSRiver Riddle     if (!consumeIf(Token::colon))
352c60b897dSRiver Riddle       type = builder.getF64Type();
353c60b897dSRiver Riddle     else if (!(type = parseType()))
354c60b897dSRiver Riddle       return nullptr;
355c60b897dSRiver Riddle   }
3565550c821STres Popp   if (!isa<FloatType>(type))
357c60b897dSRiver Riddle     return (emitError("floating point value not valid for specified type"),
358c60b897dSRiver Riddle             nullptr);
359c60b897dSRiver Riddle   return FloatAttr::get(type, isNegative ? -*val : *val);
360c60b897dSRiver Riddle }
361c60b897dSRiver Riddle 
362c60b897dSRiver Riddle /// Construct an APint from a parsed value, a known attribute type and
363c60b897dSRiver Riddle /// sign.
3640a81ace0SKazu Hirata static std::optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
365c60b897dSRiver Riddle                                                 StringRef spelling) {
366c60b897dSRiver Riddle   // Parse the integer value into an APInt that is big enough to hold the value.
367c60b897dSRiver Riddle   APInt result;
368c60b897dSRiver Riddle   bool isHex = spelling.size() > 1 && spelling[1] == 'x';
369c60b897dSRiver Riddle   if (spelling.getAsInteger(isHex ? 0 : 10, result))
3701a36588eSKazu Hirata     return std::nullopt;
371c60b897dSRiver Riddle 
372c60b897dSRiver Riddle   // Extend or truncate the bitwidth to the right size.
373c60b897dSRiver Riddle   unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
374c60b897dSRiver Riddle                                   : type.getIntOrFloatBitWidth();
375c60b897dSRiver Riddle 
376c60b897dSRiver Riddle   if (width > result.getBitWidth()) {
377c60b897dSRiver Riddle     result = result.zext(width);
378c60b897dSRiver Riddle   } else if (width < result.getBitWidth()) {
379c60b897dSRiver Riddle     // The parser can return an unnecessarily wide result with leading zeros.
380c60b897dSRiver Riddle     // This isn't a problem, but truncating off bits is bad.
381f8f3db27SKazu Hirata     if (result.countl_zero() < result.getBitWidth() - width)
3821a36588eSKazu Hirata       return std::nullopt;
383c60b897dSRiver Riddle 
384c60b897dSRiver Riddle     result = result.trunc(width);
385c60b897dSRiver Riddle   }
386c60b897dSRiver Riddle 
387c60b897dSRiver Riddle   if (width == 0) {
388c60b897dSRiver Riddle     // 0 bit integers cannot be negative and manipulation of their sign bit will
389c60b897dSRiver Riddle     // assert, so short-cut validation here.
390c60b897dSRiver Riddle     if (isNegative)
3911a36588eSKazu Hirata       return std::nullopt;
392c60b897dSRiver Riddle   } else if (isNegative) {
393c60b897dSRiver Riddle     // The value is negative, we have an overflow if the sign bit is not set
394c60b897dSRiver Riddle     // in the negated apInt.
395c60b897dSRiver Riddle     result.negate();
396c60b897dSRiver Riddle     if (!result.isSignBitSet())
3971a36588eSKazu Hirata       return std::nullopt;
398c60b897dSRiver Riddle   } else if ((type.isSignedInteger() || type.isIndex()) &&
399c60b897dSRiver Riddle              result.isSignBitSet()) {
400c60b897dSRiver Riddle     // The value is a positive signed integer or index,
401c60b897dSRiver Riddle     // we have an overflow if the sign bit is set.
4021a36588eSKazu Hirata     return std::nullopt;
403c60b897dSRiver Riddle   }
404c60b897dSRiver Riddle 
405c60b897dSRiver Riddle   return result;
406c60b897dSRiver Riddle }
407c60b897dSRiver Riddle 
408c60b897dSRiver Riddle /// Parse a decimal or a hexadecimal literal, which can be either an integer
409c60b897dSRiver Riddle /// or a float attribute.
410c60b897dSRiver Riddle Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
411c60b897dSRiver Riddle   Token tok = getToken();
412c60b897dSRiver Riddle   StringRef spelling = tok.getSpelling();
413c60b897dSRiver Riddle   SMLoc loc = tok.getLoc();
414c60b897dSRiver Riddle 
415c60b897dSRiver Riddle   consumeToken(Token::integer);
416c60b897dSRiver Riddle   if (!type) {
417c60b897dSRiver Riddle     // Default to i64 if not type is specified.
418c60b897dSRiver Riddle     if (!consumeIf(Token::colon))
419c60b897dSRiver Riddle       type = builder.getIntegerType(64);
420c60b897dSRiver Riddle     else if (!(type = parseType()))
421c60b897dSRiver Riddle       return nullptr;
422c60b897dSRiver Riddle   }
423c60b897dSRiver Riddle 
4245550c821STres Popp   if (auto floatType = dyn_cast<FloatType>(type)) {
4250a81ace0SKazu Hirata     std::optional<APFloat> result;
426c60b897dSRiver Riddle     if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
427*4548bff0SMatthias Springer                                             floatType.getFloatSemantics())))
428c60b897dSRiver Riddle       return Attribute();
429c60b897dSRiver Riddle     return FloatAttr::get(floatType, *result);
430c60b897dSRiver Riddle   }
431c60b897dSRiver Riddle 
4325550c821STres Popp   if (!isa<IntegerType, IndexType>(type))
433c60b897dSRiver Riddle     return emitError(loc, "integer literal not valid for specified type"),
434c60b897dSRiver Riddle            nullptr;
435c60b897dSRiver Riddle 
436c60b897dSRiver Riddle   if (isNegative && type.isUnsignedInteger()) {
437c60b897dSRiver Riddle     emitError(loc,
438c60b897dSRiver Riddle               "negative integer literal not valid for unsigned integer type");
439c60b897dSRiver Riddle     return nullptr;
440c60b897dSRiver Riddle   }
441c60b897dSRiver Riddle 
4420a81ace0SKazu Hirata   std::optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling);
443c60b897dSRiver Riddle   if (!apInt)
444c60b897dSRiver Riddle     return emitError(loc, "integer constant out of range for attribute"),
445c60b897dSRiver Riddle            nullptr;
446c60b897dSRiver Riddle   return builder.getIntegerAttr(type, *apInt);
447c60b897dSRiver Riddle }
448c60b897dSRiver Riddle 
449c60b897dSRiver Riddle //===----------------------------------------------------------------------===//
450c60b897dSRiver Riddle // TensorLiteralParser
451c60b897dSRiver Riddle //===----------------------------------------------------------------------===//
452c60b897dSRiver Riddle 
453c60b897dSRiver Riddle /// Parse elements values stored within a hex string. On success, the values are
454c60b897dSRiver Riddle /// stored into 'result'.
455c60b897dSRiver Riddle static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
456c60b897dSRiver Riddle                                              std::string &result) {
4570a81ace0SKazu Hirata   if (std::optional<std::string> value = tok.getHexStringValue()) {
458c60b897dSRiver Riddle     result = std::move(*value);
459c60b897dSRiver Riddle     return success();
460c60b897dSRiver Riddle   }
461c60b897dSRiver Riddle   return parser.emitError(
462c60b897dSRiver Riddle       tok.getLoc(), "expected string containing hex digits starting with `0x`");
463c60b897dSRiver Riddle }
464c60b897dSRiver Riddle 
465c60b897dSRiver Riddle namespace {
466c60b897dSRiver Riddle /// This class implements a parser for TensorLiterals. A tensor literal is
467c60b897dSRiver Riddle /// either a single element (e.g, 5) or a multi-dimensional list of elements
468c60b897dSRiver Riddle /// (e.g., [[5, 5]]).
469c60b897dSRiver Riddle class TensorLiteralParser {
470c60b897dSRiver Riddle public:
471c60b897dSRiver Riddle   TensorLiteralParser(Parser &p) : p(p) {}
472c60b897dSRiver Riddle 
473c60b897dSRiver Riddle   /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
474c60b897dSRiver Riddle   /// may also parse a tensor literal that is store as a hex string.
475c60b897dSRiver Riddle   ParseResult parse(bool allowHex);
476c60b897dSRiver Riddle 
477c60b897dSRiver Riddle   /// Build a dense attribute instance with the parsed elements and the given
478c60b897dSRiver Riddle   /// shaped type.
479c60b897dSRiver Riddle   DenseElementsAttr getAttr(SMLoc loc, ShapedType type);
480c60b897dSRiver Riddle 
481c60b897dSRiver Riddle   ArrayRef<int64_t> getShape() const { return shape; }
482c60b897dSRiver Riddle 
483c60b897dSRiver Riddle private:
484c60b897dSRiver Riddle   /// Get the parsed elements for an integer attribute.
485c60b897dSRiver Riddle   ParseResult getIntAttrElements(SMLoc loc, Type eltTy,
486c60b897dSRiver Riddle                                  std::vector<APInt> &intValues);
487c60b897dSRiver Riddle 
488c60b897dSRiver Riddle   /// Get the parsed elements for a float attribute.
489c60b897dSRiver Riddle   ParseResult getFloatAttrElements(SMLoc loc, FloatType eltTy,
490c60b897dSRiver Riddle                                    std::vector<APFloat> &floatValues);
491c60b897dSRiver Riddle 
492c60b897dSRiver Riddle   /// Build a Dense String attribute for the given type.
493c60b897dSRiver Riddle   DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy);
494c60b897dSRiver Riddle 
495c60b897dSRiver Riddle   /// Build a Dense attribute with hex data for the given type.
496c60b897dSRiver Riddle   DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type);
497c60b897dSRiver Riddle 
498c60b897dSRiver Riddle   /// Parse a single element, returning failure if it isn't a valid element
499c60b897dSRiver Riddle   /// literal. For example:
500c60b897dSRiver Riddle   /// parseElement(1) -> Success, 1
501c60b897dSRiver Riddle   /// parseElement([1]) -> Failure
502c60b897dSRiver Riddle   ParseResult parseElement();
503c60b897dSRiver Riddle 
504c60b897dSRiver Riddle   /// Parse a list of either lists or elements, returning the dimensions of the
505c60b897dSRiver Riddle   /// parsed sub-tensors in dims. For example:
506c60b897dSRiver Riddle   ///   parseList([1, 2, 3]) -> Success, [3]
507c60b897dSRiver Riddle   ///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
508c60b897dSRiver Riddle   ///   parseList([[1, 2], 3]) -> Failure
509c60b897dSRiver Riddle   ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
510c60b897dSRiver Riddle   ParseResult parseList(SmallVectorImpl<int64_t> &dims);
511c60b897dSRiver Riddle 
512c60b897dSRiver Riddle   /// Parse a literal that was printed as a hex string.
513c60b897dSRiver Riddle   ParseResult parseHexElements();
514c60b897dSRiver Riddle 
515c60b897dSRiver Riddle   Parser &p;
516c60b897dSRiver Riddle 
517c60b897dSRiver Riddle   /// The shape inferred from the parsed elements.
518c60b897dSRiver Riddle   SmallVector<int64_t, 4> shape;
519c60b897dSRiver Riddle 
520c60b897dSRiver Riddle   /// Storage used when parsing elements, this is a pair of <is_negated, token>.
521c60b897dSRiver Riddle   std::vector<std::pair<bool, Token>> storage;
522c60b897dSRiver Riddle 
523c60b897dSRiver Riddle   /// Storage used when parsing elements that were stored as hex values.
524d8be2081SKazu Hirata   std::optional<Token> hexStorage;
525c60b897dSRiver Riddle };
526c60b897dSRiver Riddle } // namespace
527c60b897dSRiver Riddle 
528c60b897dSRiver Riddle /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
529c60b897dSRiver Riddle /// may also parse a tensor literal that is store as a hex string.
530c60b897dSRiver Riddle ParseResult TensorLiteralParser::parse(bool allowHex) {
531c60b897dSRiver Riddle   // If hex is allowed, check for a string literal.
532c60b897dSRiver Riddle   if (allowHex && p.getToken().is(Token::string)) {
533c60b897dSRiver Riddle     hexStorage = p.getToken();
534c60b897dSRiver Riddle     p.consumeToken(Token::string);
535c60b897dSRiver Riddle     return success();
536c60b897dSRiver Riddle   }
537c60b897dSRiver Riddle   // Otherwise, parse a list or an individual element.
538c60b897dSRiver Riddle   if (p.getToken().is(Token::l_square))
539c60b897dSRiver Riddle     return parseList(shape);
540c60b897dSRiver Riddle   return parseElement();
541c60b897dSRiver Riddle }
542c60b897dSRiver Riddle 
543c60b897dSRiver Riddle /// Build a dense attribute instance with the parsed elements and the given
544c60b897dSRiver Riddle /// shaped type.
545c60b897dSRiver Riddle DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
546c60b897dSRiver Riddle   Type eltType = type.getElementType();
547c60b897dSRiver Riddle 
548c60b897dSRiver Riddle   // Check to see if we parse the literal from a hex string.
549c60b897dSRiver Riddle   if (hexStorage &&
5505550c821STres Popp       (eltType.isIntOrIndexOrFloat() || isa<ComplexType>(eltType)))
551c60b897dSRiver Riddle     return getHexAttr(loc, type);
552c60b897dSRiver Riddle 
553c60b897dSRiver Riddle   // Check that the parsed storage size has the same number of elements to the
554c60b897dSRiver Riddle   // type, or is a known splat.
555c60b897dSRiver Riddle   if (!shape.empty() && getShape() != type.getShape()) {
556c60b897dSRiver Riddle     p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
557c60b897dSRiver Riddle                      << "]) does not match type ([" << type.getShape() << "])";
558c60b897dSRiver Riddle     return nullptr;
559c60b897dSRiver Riddle   }
560c60b897dSRiver Riddle 
561c60b897dSRiver Riddle   // Handle the case where no elements were parsed.
562c60b897dSRiver Riddle   if (!hexStorage && storage.empty() && type.getNumElements()) {
563c60b897dSRiver Riddle     p.emitError(loc) << "parsed zero elements, but type (" << type
564c60b897dSRiver Riddle                      << ") expected at least 1";
565c60b897dSRiver Riddle     return nullptr;
566c60b897dSRiver Riddle   }
567c60b897dSRiver Riddle 
568c60b897dSRiver Riddle   // Handle complex types in the specific element type cases below.
569c60b897dSRiver Riddle   bool isComplex = false;
5705550c821STres Popp   if (ComplexType complexTy = dyn_cast<ComplexType>(eltType)) {
571c60b897dSRiver Riddle     eltType = complexTy.getElementType();
572c60b897dSRiver Riddle     isComplex = true;
573c60b897dSRiver Riddle   }
574c60b897dSRiver Riddle 
575c60b897dSRiver Riddle   // Handle integer and index types.
576c60b897dSRiver Riddle   if (eltType.isIntOrIndex()) {
577c60b897dSRiver Riddle     std::vector<APInt> intValues;
578c60b897dSRiver Riddle     if (failed(getIntAttrElements(loc, eltType, intValues)))
579c60b897dSRiver Riddle       return nullptr;
580c60b897dSRiver Riddle     if (isComplex) {
581c60b897dSRiver Riddle       // If this is a complex, treat the parsed values as complex values.
582984b800aSserge-sans-paille       auto complexData = llvm::ArrayRef(
583c60b897dSRiver Riddle           reinterpret_cast<std::complex<APInt> *>(intValues.data()),
584c60b897dSRiver Riddle           intValues.size() / 2);
585c60b897dSRiver Riddle       return DenseElementsAttr::get(type, complexData);
586c60b897dSRiver Riddle     }
587c60b897dSRiver Riddle     return DenseElementsAttr::get(type, intValues);
588c60b897dSRiver Riddle   }
589c60b897dSRiver Riddle   // Handle floating point types.
5905550c821STres Popp   if (FloatType floatTy = dyn_cast<FloatType>(eltType)) {
591c60b897dSRiver Riddle     std::vector<APFloat> floatValues;
592c60b897dSRiver Riddle     if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
593c60b897dSRiver Riddle       return nullptr;
594c60b897dSRiver Riddle     if (isComplex) {
595c60b897dSRiver Riddle       // If this is a complex, treat the parsed values as complex values.
596984b800aSserge-sans-paille       auto complexData = llvm::ArrayRef(
597c60b897dSRiver Riddle           reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
598c60b897dSRiver Riddle           floatValues.size() / 2);
599c60b897dSRiver Riddle       return DenseElementsAttr::get(type, complexData);
600c60b897dSRiver Riddle     }
601c60b897dSRiver Riddle     return DenseElementsAttr::get(type, floatValues);
602c60b897dSRiver Riddle   }
603c60b897dSRiver Riddle 
604c60b897dSRiver Riddle   // Other types are assumed to be string representations.
605c60b897dSRiver Riddle   return getStringAttr(loc, type, type.getElementType());
606c60b897dSRiver Riddle }
607c60b897dSRiver Riddle 
608c60b897dSRiver Riddle /// Build a Dense Integer attribute for the given type.
609c60b897dSRiver Riddle ParseResult
610c60b897dSRiver Riddle TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
611c60b897dSRiver Riddle                                         std::vector<APInt> &intValues) {
612c60b897dSRiver Riddle   intValues.reserve(storage.size());
613c60b897dSRiver Riddle   bool isUintType = eltTy.isUnsignedInteger();
614c60b897dSRiver Riddle   for (const auto &signAndToken : storage) {
615c60b897dSRiver Riddle     bool isNegative = signAndToken.first;
616c60b897dSRiver Riddle     const Token &token = signAndToken.second;
617c60b897dSRiver Riddle     auto tokenLoc = token.getLoc();
618c60b897dSRiver Riddle 
619c60b897dSRiver Riddle     if (isNegative && isUintType) {
620c60b897dSRiver Riddle       return p.emitError(tokenLoc)
621c60b897dSRiver Riddle              << "expected unsigned integer elements, but parsed negative value";
622c60b897dSRiver Riddle     }
623c60b897dSRiver Riddle 
624c60b897dSRiver Riddle     // Check to see if floating point values were parsed.
625c60b897dSRiver Riddle     if (token.is(Token::floatliteral)) {
626c60b897dSRiver Riddle       return p.emitError(tokenLoc)
627c60b897dSRiver Riddle              << "expected integer elements, but parsed floating-point";
628c60b897dSRiver Riddle     }
629c60b897dSRiver Riddle 
630c60b897dSRiver Riddle     assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
631c60b897dSRiver Riddle            "unexpected token type");
632c60b897dSRiver Riddle     if (token.isAny(Token::kw_true, Token::kw_false)) {
633c60b897dSRiver Riddle       if (!eltTy.isInteger(1)) {
634c60b897dSRiver Riddle         return p.emitError(tokenLoc)
635c60b897dSRiver Riddle                << "expected i1 type for 'true' or 'false' values";
636c60b897dSRiver Riddle       }
637c60b897dSRiver Riddle       APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
638c60b897dSRiver Riddle       intValues.push_back(apInt);
639c60b897dSRiver Riddle       continue;
640c60b897dSRiver Riddle     }
641c60b897dSRiver Riddle 
642c60b897dSRiver Riddle     // Create APInt values for each element with the correct bitwidth.
6430a81ace0SKazu Hirata     std::optional<APInt> apInt =
644c60b897dSRiver Riddle         buildAttributeAPInt(eltTy, isNegative, token.getSpelling());
645c60b897dSRiver Riddle     if (!apInt)
646c60b897dSRiver Riddle       return p.emitError(tokenLoc, "integer constant out of range for type");
647c60b897dSRiver Riddle     intValues.push_back(*apInt);
648c60b897dSRiver Riddle   }
649c60b897dSRiver Riddle   return success();
650c60b897dSRiver Riddle }
651c60b897dSRiver Riddle 
652c60b897dSRiver Riddle /// Build a Dense Float attribute for the given type.
653c60b897dSRiver Riddle ParseResult
654c60b897dSRiver Riddle TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
655c60b897dSRiver Riddle                                           std::vector<APFloat> &floatValues) {
656c60b897dSRiver Riddle   floatValues.reserve(storage.size());
657c60b897dSRiver Riddle   for (const auto &signAndToken : storage) {
658c60b897dSRiver Riddle     bool isNegative = signAndToken.first;
659c60b897dSRiver Riddle     const Token &token = signAndToken.second;
6600a81ace0SKazu Hirata     std::optional<APFloat> result;
661*4548bff0SMatthias Springer     if (failed(p.parseFloatFromLiteral(result, token, isNegative,
662*4548bff0SMatthias Springer                                        eltTy.getFloatSemantics())))
663c60b897dSRiver Riddle       return failure();
664c60b897dSRiver Riddle     floatValues.push_back(*result);
665c60b897dSRiver Riddle   }
666c60b897dSRiver Riddle   return success();
667c60b897dSRiver Riddle }
668c60b897dSRiver Riddle 
669c60b897dSRiver Riddle /// Build a Dense String attribute for the given type.
670c60b897dSRiver Riddle DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type,
671c60b897dSRiver Riddle                                                      Type eltTy) {
672c60b897dSRiver Riddle   if (hexStorage.has_value()) {
6734913e5daSFangrui Song     auto stringValue = hexStorage->getStringValue();
674c60b897dSRiver Riddle     return DenseStringElementsAttr::get(type, {stringValue});
675c60b897dSRiver Riddle   }
676c60b897dSRiver Riddle 
677c60b897dSRiver Riddle   std::vector<std::string> stringValues;
678c60b897dSRiver Riddle   std::vector<StringRef> stringRefValues;
679c60b897dSRiver Riddle   stringValues.reserve(storage.size());
680c60b897dSRiver Riddle   stringRefValues.reserve(storage.size());
681c60b897dSRiver Riddle 
682c60b897dSRiver Riddle   for (auto val : storage) {
683c60b897dSRiver Riddle     stringValues.push_back(val.second.getStringValue());
684c60b897dSRiver Riddle     stringRefValues.emplace_back(stringValues.back());
685c60b897dSRiver Riddle   }
686c60b897dSRiver Riddle 
687c60b897dSRiver Riddle   return DenseStringElementsAttr::get(type, stringRefValues);
688c60b897dSRiver Riddle }
689c60b897dSRiver Riddle 
690c60b897dSRiver Riddle /// Build a Dense attribute with hex data for the given type.
691c60b897dSRiver Riddle DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) {
692c60b897dSRiver Riddle   Type elementType = type.getElementType();
6935550c821STres Popp   if (!elementType.isIntOrIndexOrFloat() && !isa<ComplexType>(elementType)) {
694c60b897dSRiver Riddle     p.emitError(loc)
695c60b897dSRiver Riddle         << "expected floating-point, integer, or complex element type, got "
696c60b897dSRiver Riddle         << elementType;
697c60b897dSRiver Riddle     return nullptr;
698c60b897dSRiver Riddle   }
699c60b897dSRiver Riddle 
700c60b897dSRiver Riddle   std::string data;
701c60b897dSRiver Riddle   if (parseElementAttrHexValues(p, *hexStorage, data))
702c60b897dSRiver Riddle     return nullptr;
703c60b897dSRiver Riddle 
704c60b897dSRiver Riddle   ArrayRef<char> rawData(data.data(), data.size());
705c60b897dSRiver Riddle   bool detectedSplat = false;
706c60b897dSRiver Riddle   if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) {
707c60b897dSRiver Riddle     p.emitError(loc) << "elements hex data size is invalid for provided type: "
708c60b897dSRiver Riddle                      << type;
709c60b897dSRiver Riddle     return nullptr;
710c60b897dSRiver Riddle   }
711c60b897dSRiver Riddle 
7126b31b026SKazu Hirata   if (llvm::endianness::native == llvm::endianness::big) {
713c60b897dSRiver Riddle     // Convert endianess in big-endian(BE) machines. `rawData` is
714c60b897dSRiver Riddle     // little-endian(LE) because HEX in raw data of dense element attribute
715c60b897dSRiver Riddle     // is always LE format. It is converted into BE here to be used in BE
716c60b897dSRiver Riddle     // machines.
717c60b897dSRiver Riddle     SmallVector<char, 64> outDataVec(rawData.size());
718c60b897dSRiver Riddle     MutableArrayRef<char> convRawData(outDataVec);
719c60b897dSRiver Riddle     DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
720c60b897dSRiver Riddle         rawData, convRawData, type);
721c60b897dSRiver Riddle     return DenseElementsAttr::getFromRawBuffer(type, convRawData);
722c60b897dSRiver Riddle   }
723c60b897dSRiver Riddle 
724c60b897dSRiver Riddle   return DenseElementsAttr::getFromRawBuffer(type, rawData);
725c60b897dSRiver Riddle }
726c60b897dSRiver Riddle 
727c60b897dSRiver Riddle ParseResult TensorLiteralParser::parseElement() {
728c60b897dSRiver Riddle   switch (p.getToken().getKind()) {
729c60b897dSRiver Riddle   // Parse a boolean element.
730c60b897dSRiver Riddle   case Token::kw_true:
731c60b897dSRiver Riddle   case Token::kw_false:
732c60b897dSRiver Riddle   case Token::floatliteral:
733c60b897dSRiver Riddle   case Token::integer:
734c60b897dSRiver Riddle     storage.emplace_back(/*isNegative=*/false, p.getToken());
735c60b897dSRiver Riddle     p.consumeToken();
736c60b897dSRiver Riddle     break;
737c60b897dSRiver Riddle 
738c60b897dSRiver Riddle   // Parse a signed integer or a negative floating-point element.
739c60b897dSRiver Riddle   case Token::minus:
740c60b897dSRiver Riddle     p.consumeToken(Token::minus);
741c60b897dSRiver Riddle     if (!p.getToken().isAny(Token::floatliteral, Token::integer))
742c60b897dSRiver Riddle       return p.emitError("expected integer or floating point literal");
743c60b897dSRiver Riddle     storage.emplace_back(/*isNegative=*/true, p.getToken());
744c60b897dSRiver Riddle     p.consumeToken();
745c60b897dSRiver Riddle     break;
746c60b897dSRiver Riddle 
747c60b897dSRiver Riddle   case Token::string:
748c60b897dSRiver Riddle     storage.emplace_back(/*isNegative=*/false, p.getToken());
749c60b897dSRiver Riddle     p.consumeToken();
750c60b897dSRiver Riddle     break;
751c60b897dSRiver Riddle 
752c60b897dSRiver Riddle   // Parse a complex element of the form '(' element ',' element ')'.
753c60b897dSRiver Riddle   case Token::l_paren:
754c60b897dSRiver Riddle     p.consumeToken(Token::l_paren);
755c60b897dSRiver Riddle     if (parseElement() ||
756c60b897dSRiver Riddle         p.parseToken(Token::comma, "expected ',' between complex elements") ||
757c60b897dSRiver Riddle         parseElement() ||
758c60b897dSRiver Riddle         p.parseToken(Token::r_paren, "expected ')' after complex elements"))
759c60b897dSRiver Riddle       return failure();
760c60b897dSRiver Riddle     break;
761c60b897dSRiver Riddle 
762c60b897dSRiver Riddle   default:
763c60b897dSRiver Riddle     return p.emitError("expected element literal of primitive type");
764c60b897dSRiver Riddle   }
765c60b897dSRiver Riddle 
766c60b897dSRiver Riddle   return success();
767c60b897dSRiver Riddle }
768c60b897dSRiver Riddle 
769c60b897dSRiver Riddle /// Parse a list of either lists or elements, returning the dimensions of the
770c60b897dSRiver Riddle /// parsed sub-tensors in dims. For example:
771c60b897dSRiver Riddle ///   parseList([1, 2, 3]) -> Success, [3]
772c60b897dSRiver Riddle ///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
773c60b897dSRiver Riddle ///   parseList([[1, 2], 3]) -> Failure
774c60b897dSRiver Riddle ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
775c60b897dSRiver Riddle ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
776c60b897dSRiver Riddle   auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
777c60b897dSRiver Riddle                        const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
778c60b897dSRiver Riddle     if (prevDims == newDims)
779c60b897dSRiver Riddle       return success();
780c60b897dSRiver Riddle     return p.emitError("tensor literal is invalid; ranks are not consistent "
781c60b897dSRiver Riddle                        "between elements");
782c60b897dSRiver Riddle   };
783c60b897dSRiver Riddle 
784c60b897dSRiver Riddle   bool first = true;
785c60b897dSRiver Riddle   SmallVector<int64_t, 4> newDims;
786c60b897dSRiver Riddle   unsigned size = 0;
787c60b897dSRiver Riddle   auto parseOneElement = [&]() -> ParseResult {
788c60b897dSRiver Riddle     SmallVector<int64_t, 4> thisDims;
789c60b897dSRiver Riddle     if (p.getToken().getKind() == Token::l_square) {
790c60b897dSRiver Riddle       if (parseList(thisDims))
791c60b897dSRiver Riddle         return failure();
792c60b897dSRiver Riddle     } else if (parseElement()) {
793c60b897dSRiver Riddle       return failure();
794c60b897dSRiver Riddle     }
795c60b897dSRiver Riddle     ++size;
796c60b897dSRiver Riddle     if (!first)
797c60b897dSRiver Riddle       return checkDims(newDims, thisDims);
798c60b897dSRiver Riddle     newDims = thisDims;
799c60b897dSRiver Riddle     first = false;
800c60b897dSRiver Riddle     return success();
801c60b897dSRiver Riddle   };
802c60b897dSRiver Riddle   if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseOneElement))
803c60b897dSRiver Riddle     return failure();
804c60b897dSRiver Riddle 
805c60b897dSRiver Riddle   // Return the sublists' dimensions with 'size' prepended.
806c60b897dSRiver Riddle   dims.clear();
807c60b897dSRiver Riddle   dims.push_back(size);
808c60b897dSRiver Riddle   dims.append(newDims.begin(), newDims.end());
809c60b897dSRiver Riddle   return success();
810c60b897dSRiver Riddle }
811c60b897dSRiver Riddle 
812c60b897dSRiver Riddle //===----------------------------------------------------------------------===//
813cec7e80eSJeff Niu // DenseArrayAttr Parser
814c60b897dSRiver Riddle //===----------------------------------------------------------------------===//
815c60b897dSRiver Riddle 
816c60b897dSRiver Riddle namespace {
817cec7e80eSJeff Niu /// A generic dense array element parser. It parsers integer and floating point
818cec7e80eSJeff Niu /// elements.
819cec7e80eSJeff Niu class DenseArrayElementParser {
820c60b897dSRiver Riddle public:
821cec7e80eSJeff Niu   explicit DenseArrayElementParser(Type type) : type(type) {}
822cec7e80eSJeff Niu 
823cec7e80eSJeff Niu   /// Parse an integer element.
824cec7e80eSJeff Niu   ParseResult parseIntegerElement(Parser &p);
825cec7e80eSJeff Niu 
826cec7e80eSJeff Niu   /// Parse a floating point element.
827cec7e80eSJeff Niu   ParseResult parseFloatElement(Parser &p);
828cec7e80eSJeff Niu 
829cec7e80eSJeff Niu   /// Convert the current contents to a dense array.
830c48e0cf0SJeff Niu   DenseArrayAttr getAttr() { return DenseArrayAttr::get(type, size, rawData); }
831cec7e80eSJeff Niu 
832cec7e80eSJeff Niu private:
833cec7e80eSJeff Niu   /// Append the raw data of an APInt to the result.
834cec7e80eSJeff Niu   void append(const APInt &data);
835cec7e80eSJeff Niu 
836cec7e80eSJeff Niu   /// The array element type.
837cec7e80eSJeff Niu   Type type;
838cec7e80eSJeff Niu   /// The resultant byte array representing the contents of the array.
839cec7e80eSJeff Niu   std::vector<char> rawData;
840cec7e80eSJeff Niu   /// The number of elements in the array.
841cec7e80eSJeff Niu   int64_t size = 0;
842c60b897dSRiver Riddle };
843c60b897dSRiver Riddle } // namespace
844c60b897dSRiver Riddle 
845cec7e80eSJeff Niu void DenseArrayElementParser::append(const APInt &data) {
846cbf81558SVitaly Buka   if (data.getBitWidth()) {
847cbf81558SVitaly Buka     assert(data.getBitWidth() % 8 == 0);
848cec7e80eSJeff Niu     unsigned byteSize = data.getBitWidth() / 8;
849cec7e80eSJeff Niu     size_t offset = rawData.size();
850cec7e80eSJeff Niu     rawData.insert(rawData.end(), byteSize, 0);
851cec7e80eSJeff Niu     llvm::StoreIntToMemory(
852cec7e80eSJeff Niu         data, reinterpret_cast<uint8_t *>(rawData.data() + offset), byteSize);
853cbf81558SVitaly Buka   }
854cec7e80eSJeff Niu   ++size;
855cec7e80eSJeff Niu }
856cec7e80eSJeff Niu 
857cec7e80eSJeff Niu ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) {
858cec7e80eSJeff Niu   bool isNegative = p.consumeIf(Token::minus);
859cec7e80eSJeff Niu 
860cec7e80eSJeff Niu   // Parse an integer literal as an APInt.
8610a81ace0SKazu Hirata   std::optional<APInt> value;
862cec7e80eSJeff Niu   StringRef spelling = p.getToken().getSpelling();
863cec7e80eSJeff Niu   if (p.getToken().isAny(Token::kw_true, Token::kw_false)) {
864cec7e80eSJeff Niu     if (!type.isInteger(1))
865cec7e80eSJeff Niu       return p.emitError("expected i1 type for 'true' or 'false' values");
866cec7e80eSJeff Niu     value = APInt(/*numBits=*/8, p.getToken().is(Token::kw_true),
867cec7e80eSJeff Niu                   !type.isUnsignedInteger());
868cec7e80eSJeff Niu     p.consumeToken();
869cec7e80eSJeff Niu   } else if (p.consumeIf(Token::integer)) {
870cec7e80eSJeff Niu     value = buildAttributeAPInt(type, isNegative, spelling);
871cec7e80eSJeff Niu     if (!value)
872cec7e80eSJeff Niu       return p.emitError("integer constant out of range");
873cec7e80eSJeff Niu   } else {
874cec7e80eSJeff Niu     return p.emitError("expected integer literal");
875cec7e80eSJeff Niu   }
876cec7e80eSJeff Niu   append(*value);
877cec7e80eSJeff Niu   return success();
878cec7e80eSJeff Niu }
879cec7e80eSJeff Niu 
880cec7e80eSJeff Niu ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
881cec7e80eSJeff Niu   bool isNegative = p.consumeIf(Token::minus);
882cec7e80eSJeff Niu   Token token = p.getToken();
883*4548bff0SMatthias Springer   std::optional<APFloat> fromIntLit;
884*4548bff0SMatthias Springer   if (failed(
885*4548bff0SMatthias Springer           p.parseFloatFromLiteral(fromIntLit, token, isNegative,
886*4548bff0SMatthias Springer                                   cast<FloatType>(type).getFloatSemantics())))
887cec7e80eSJeff Niu     return failure();
888*4548bff0SMatthias Springer   p.consumeToken();
889*4548bff0SMatthias Springer   append(fromIntLit->bitcastToAPInt());
890cec7e80eSJeff Niu   return success();
891cec7e80eSJeff Niu }
892cec7e80eSJeff Niu 
893c60b897dSRiver Riddle /// Parse a dense array attribute.
8947a7c0697SJeff Niu Attribute Parser::parseDenseArrayAttr(Type attrType) {
8952092d143SJeff Niu   consumeToken(Token::kw_array);
896cec7e80eSJeff Niu   if (parseToken(Token::less, "expected '<' after 'array'"))
89796da738dSJeff Niu     return {};
898c60b897dSRiver Riddle 
899cec7e80eSJeff Niu   SMLoc typeLoc = getToken().getLoc();
900c48e0cf0SJeff Niu   Type eltType = parseType();
901c48e0cf0SJeff Niu   if (!eltType) {
902c48e0cf0SJeff Niu     emitError(typeLoc, "expected an integer or floating point type");
9037a7c0697SJeff Niu     return {};
9047a7c0697SJeff Niu   }
9057a7c0697SJeff Niu 
9067a7c0697SJeff Niu   // Only bool or integer and floating point elements divisible by bytes are
9077a7c0697SJeff Niu   // supported.
9087a7c0697SJeff Niu   if (!eltType.isIntOrIndexOrFloat()) {
9097a7c0697SJeff Niu     emitError(typeLoc, "expected integer or float type, got: ") << eltType;
9107a7c0697SJeff Niu     return {};
9117a7c0697SJeff Niu   }
9127a7c0697SJeff Niu   if (!eltType.isInteger(1) && eltType.getIntOrFloatBitWidth() % 8 != 0) {
913cec7e80eSJeff Niu     emitError(typeLoc, "element type bitwidth must be a multiple of 8");
914cec7e80eSJeff Niu     return {};
915cec7e80eSJeff Niu   }
916cec7e80eSJeff Niu 
917cec7e80eSJeff Niu   // Check for empty list.
918c48e0cf0SJeff Niu   if (consumeIf(Token::greater))
919c48e0cf0SJeff Niu     return DenseArrayAttr::get(eltType, 0, {});
920c48e0cf0SJeff Niu 
921c48e0cf0SJeff Niu   if (parseToken(Token::colon, "expected ':' after dense array type"))
922cec7e80eSJeff Niu     return {};
923cec7e80eSJeff Niu 
9247a7c0697SJeff Niu   DenseArrayElementParser eltParser(eltType);
9257a7c0697SJeff Niu   if (eltType.isIntOrIndex()) {
926cec7e80eSJeff Niu     if (parseCommaSeparatedList(
927cec7e80eSJeff Niu             [&] { return eltParser.parseIntegerElement(*this); }))
928cec7e80eSJeff Niu       return {};
929cec7e80eSJeff Niu   } else {
930cec7e80eSJeff Niu     if (parseCommaSeparatedList(
931cec7e80eSJeff Niu             [&] { return eltParser.parseFloatElement(*this); }))
932cec7e80eSJeff Niu       return {};
933cec7e80eSJeff Niu   }
9342092d143SJeff Niu   if (parseToken(Token::greater, "expected '>' to close an array attribute"))
935c60b897dSRiver Riddle     return {};
936c48e0cf0SJeff Niu   return eltParser.getAttr();
937c60b897dSRiver Riddle }
938c60b897dSRiver Riddle 
939c60b897dSRiver Riddle /// Parse a dense elements attribute.
940c60b897dSRiver Riddle Attribute Parser::parseDenseElementsAttr(Type attrType) {
941c60b897dSRiver Riddle   auto attribLoc = getToken().getLoc();
942c60b897dSRiver Riddle   consumeToken(Token::kw_dense);
943c60b897dSRiver Riddle   if (parseToken(Token::less, "expected '<' after 'dense'"))
944c60b897dSRiver Riddle     return nullptr;
945c60b897dSRiver Riddle 
946c60b897dSRiver Riddle   // Parse the literal data if necessary.
947c60b897dSRiver Riddle   TensorLiteralParser literalParser(*this);
948c60b897dSRiver Riddle   if (!consumeIf(Token::greater)) {
949c60b897dSRiver Riddle     if (literalParser.parse(/*allowHex=*/true) ||
950c60b897dSRiver Riddle         parseToken(Token::greater, "expected '>'"))
951c60b897dSRiver Riddle       return nullptr;
952c60b897dSRiver Riddle   }
953c60b897dSRiver Riddle 
954c60b897dSRiver Riddle   // If the type is specified `parseElementsLiteralType` will not parse a type.
955c60b897dSRiver Riddle   // Use the attribute location as the location for error reporting in that
956c60b897dSRiver Riddle   // case.
957c60b897dSRiver Riddle   auto loc = attrType ? attribLoc : getToken().getLoc();
958c60b897dSRiver Riddle   auto type = parseElementsLiteralType(attrType);
959c60b897dSRiver Riddle   if (!type)
960c60b897dSRiver Riddle     return nullptr;
961c60b897dSRiver Riddle   return literalParser.getAttr(loc, type);
962c60b897dSRiver Riddle }
963c60b897dSRiver Riddle 
964995ab929SRiver Riddle Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
965995ab929SRiver Riddle   auto loc = getToken().getLoc();
966995ab929SRiver Riddle   consumeToken(Token::kw_dense_resource);
967995ab929SRiver Riddle   if (parseToken(Token::less, "expected '<' after 'dense_resource'"))
968995ab929SRiver Riddle     return nullptr;
969995ab929SRiver Riddle 
970995ab929SRiver Riddle   // Parse the resource handle.
971995ab929SRiver Riddle   FailureOr<AsmDialectResourceHandle> rawHandle =
972995ab929SRiver Riddle       parseResourceHandle(getContext()->getLoadedDialect<BuiltinDialect>());
973995ab929SRiver Riddle   if (failed(rawHandle) || parseToken(Token::greater, "expected '>'"))
974995ab929SRiver Riddle     return nullptr;
975995ab929SRiver Riddle 
976995ab929SRiver Riddle   auto *handle = dyn_cast<DenseResourceElementsHandle>(&*rawHandle);
977995ab929SRiver Riddle   if (!handle)
978995ab929SRiver Riddle     return emitError(loc, "invalid `dense_resource` handle type"), nullptr;
979995ab929SRiver Riddle 
980995ab929SRiver Riddle   // Parse the type of the attribute if the user didn't provide one.
981995ab929SRiver Riddle   SMLoc typeLoc = loc;
982995ab929SRiver Riddle   if (!attrType) {
983995ab929SRiver Riddle     typeLoc = getToken().getLoc();
984995ab929SRiver Riddle     if (parseToken(Token::colon, "expected ':'") || !(attrType = parseType()))
985995ab929SRiver Riddle       return nullptr;
986995ab929SRiver Riddle   }
987995ab929SRiver Riddle 
9885550c821STres Popp   ShapedType shapedType = dyn_cast<ShapedType>(attrType);
989995ab929SRiver Riddle   if (!shapedType) {
990995ab929SRiver Riddle     emitError(typeLoc, "`dense_resource` expected a shaped type");
991995ab929SRiver Riddle     return nullptr;
992995ab929SRiver Riddle   }
993995ab929SRiver Riddle 
994995ab929SRiver Riddle   return DenseResourceElementsAttr::get(shapedType, *handle);
995995ab929SRiver Riddle }
996995ab929SRiver Riddle 
997c60b897dSRiver Riddle /// Shaped type for elements attribute.
998c60b897dSRiver Riddle ///
999c60b897dSRiver Riddle ///   elements-literal-type ::= vector-type | ranked-tensor-type
1000c60b897dSRiver Riddle ///
1001c60b897dSRiver Riddle /// This method also checks the type has static shape.
1002c60b897dSRiver Riddle ShapedType Parser::parseElementsLiteralType(Type type) {
1003c60b897dSRiver Riddle   // If the user didn't provide a type, parse the colon type for the literal.
1004c60b897dSRiver Riddle   if (!type) {
1005c60b897dSRiver Riddle     if (parseToken(Token::colon, "expected ':'"))
1006c60b897dSRiver Riddle       return nullptr;
1007c60b897dSRiver Riddle     if (!(type = parseType()))
1008c60b897dSRiver Riddle       return nullptr;
1009c60b897dSRiver Riddle   }
1010c60b897dSRiver Riddle 
10115550c821STres Popp   auto sType = dyn_cast<ShapedType>(type);
1012a5c46bf9SJeff Niu   if (!sType) {
1013a5c46bf9SJeff Niu     emitError("elements literal must be a shaped type");
1014c60b897dSRiver Riddle     return nullptr;
1015c60b897dSRiver Riddle   }
1016c60b897dSRiver Riddle 
1017c60b897dSRiver Riddle   if (!sType.hasStaticShape())
1018c60b897dSRiver Riddle     return (emitError("elements literal type must have static shape"), nullptr);
1019c60b897dSRiver Riddle 
1020c60b897dSRiver Riddle   return sType;
1021c60b897dSRiver Riddle }
1022c60b897dSRiver Riddle 
1023c60b897dSRiver Riddle /// Parse a sparse elements attribute.
1024c60b897dSRiver Riddle Attribute Parser::parseSparseElementsAttr(Type attrType) {
1025c60b897dSRiver Riddle   SMLoc loc = getToken().getLoc();
1026c60b897dSRiver Riddle   consumeToken(Token::kw_sparse);
1027c60b897dSRiver Riddle   if (parseToken(Token::less, "Expected '<' after 'sparse'"))
1028c60b897dSRiver Riddle     return nullptr;
1029c60b897dSRiver Riddle 
1030c60b897dSRiver Riddle   // Check for the case where all elements are sparse. The indices are
1031c60b897dSRiver Riddle   // represented by a 2-dimensional shape where the second dimension is the rank
1032c60b897dSRiver Riddle   // of the type.
1033c60b897dSRiver Riddle   Type indiceEltType = builder.getIntegerType(64);
1034c60b897dSRiver Riddle   if (consumeIf(Token::greater)) {
1035c60b897dSRiver Riddle     ShapedType type = parseElementsLiteralType(attrType);
1036c60b897dSRiver Riddle     if (!type)
1037c60b897dSRiver Riddle       return nullptr;
1038c60b897dSRiver Riddle 
1039c60b897dSRiver Riddle     // Construct the sparse elements attr using zero element indice/value
1040c60b897dSRiver Riddle     // attributes.
1041c60b897dSRiver Riddle     ShapedType indicesType =
1042c60b897dSRiver Riddle         RankedTensorType::get({0, type.getRank()}, indiceEltType);
1043c60b897dSRiver Riddle     ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
1044c60b897dSRiver Riddle     return getChecked<SparseElementsAttr>(
1045c60b897dSRiver Riddle         loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
1046c60b897dSRiver Riddle         DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
1047c60b897dSRiver Riddle   }
1048c60b897dSRiver Riddle 
1049c60b897dSRiver Riddle   /// Parse the indices. We don't allow hex values here as we may need to use
1050c60b897dSRiver Riddle   /// the inferred shape.
1051c60b897dSRiver Riddle   auto indicesLoc = getToken().getLoc();
1052c60b897dSRiver Riddle   TensorLiteralParser indiceParser(*this);
1053c60b897dSRiver Riddle   if (indiceParser.parse(/*allowHex=*/false))
1054c60b897dSRiver Riddle     return nullptr;
1055c60b897dSRiver Riddle 
1056c60b897dSRiver Riddle   if (parseToken(Token::comma, "expected ','"))
1057c60b897dSRiver Riddle     return nullptr;
1058c60b897dSRiver Riddle 
1059c60b897dSRiver Riddle   /// Parse the values.
1060c60b897dSRiver Riddle   auto valuesLoc = getToken().getLoc();
1061c60b897dSRiver Riddle   TensorLiteralParser valuesParser(*this);
1062c60b897dSRiver Riddle   if (valuesParser.parse(/*allowHex=*/true))
1063c60b897dSRiver Riddle     return nullptr;
1064c60b897dSRiver Riddle 
1065c60b897dSRiver Riddle   if (parseToken(Token::greater, "expected '>'"))
1066c60b897dSRiver Riddle     return nullptr;
1067c60b897dSRiver Riddle 
1068c60b897dSRiver Riddle   auto type = parseElementsLiteralType(attrType);
1069c60b897dSRiver Riddle   if (!type)
1070c60b897dSRiver Riddle     return nullptr;
1071c60b897dSRiver Riddle 
1072c60b897dSRiver Riddle   // If the indices are a splat, i.e. the literal parser parsed an element and
1073c60b897dSRiver Riddle   // not a list, we set the shape explicitly. The indices are represented by a
1074c60b897dSRiver Riddle   // 2-dimensional shape where the second dimension is the rank of the type.
1075c60b897dSRiver Riddle   // Given that the parsed indices is a splat, we know that we only have one
1076c60b897dSRiver Riddle   // indice and thus one for the first dimension.
1077c60b897dSRiver Riddle   ShapedType indicesType;
1078c60b897dSRiver Riddle   if (indiceParser.getShape().empty()) {
1079c60b897dSRiver Riddle     indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
1080c60b897dSRiver Riddle   } else {
1081c60b897dSRiver Riddle     // Otherwise, set the shape to the one parsed by the literal parser.
1082c60b897dSRiver Riddle     indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
1083c60b897dSRiver Riddle   }
1084c60b897dSRiver Riddle   auto indices = indiceParser.getAttr(indicesLoc, indicesType);
1085c60b897dSRiver Riddle 
1086c60b897dSRiver Riddle   // If the values are a splat, set the shape explicitly based on the number of
1087c60b897dSRiver Riddle   // indices. The number of indices is encoded in the first dimension of the
1088c60b897dSRiver Riddle   // indice shape type.
1089c60b897dSRiver Riddle   auto valuesEltType = type.getElementType();
1090c60b897dSRiver Riddle   ShapedType valuesType =
1091c60b897dSRiver Riddle       valuesParser.getShape().empty()
1092c60b897dSRiver Riddle           ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
1093c60b897dSRiver Riddle           : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
1094c60b897dSRiver Riddle   auto values = valuesParser.getAttr(valuesLoc, valuesType);
1095c60b897dSRiver Riddle 
1096c60b897dSRiver Riddle   // Build the sparse elements attribute by the indices and values.
1097c60b897dSRiver Riddle   return getChecked<SparseElementsAttr>(loc, type, indices, values);
1098c60b897dSRiver Riddle }
1099519847feSAlex Zinenko 
1100519847feSAlex Zinenko Attribute Parser::parseStridedLayoutAttr() {
1101519847feSAlex Zinenko   // Callback for error emissing at the keyword token location.
1102519847feSAlex Zinenko   llvm::SMLoc loc = getToken().getLoc();
1103519847feSAlex Zinenko   auto errorEmitter = [&] { return emitError(loc); };
1104519847feSAlex Zinenko 
1105519847feSAlex Zinenko   consumeToken(Token::kw_strided);
1106519847feSAlex Zinenko   if (failed(parseToken(Token::less, "expected '<' after 'strided'")) ||
1107519847feSAlex Zinenko       failed(parseToken(Token::l_square, "expected '['")))
1108519847feSAlex Zinenko     return nullptr;
1109519847feSAlex Zinenko 
1110519847feSAlex Zinenko   // Parses either an integer token or a question mark token. Reports an error
111170c73d1bSKazu Hirata   // and returns std::nullopt if the current token is neither. The integer token
111270c73d1bSKazu Hirata   // must fit into int64_t limits.
11130a81ace0SKazu Hirata   auto parseStrideOrOffset = [&]() -> std::optional<int64_t> {
1114519847feSAlex Zinenko     if (consumeIf(Token::question))
1115399638f9SAliia Khasanova       return ShapedType::kDynamic;
1116519847feSAlex Zinenko 
1117519847feSAlex Zinenko     SMLoc loc = getToken().getLoc();
1118519847feSAlex Zinenko     auto emitWrongTokenError = [&] {
111954d81e49SIvan Butygin       emitError(loc, "expected a 64-bit signed integer or '?'");
11201a36588eSKazu Hirata       return std::nullopt;
1121519847feSAlex Zinenko     };
1122519847feSAlex Zinenko 
112354d81e49SIvan Butygin     bool negative = consumeIf(Token::minus);
112454d81e49SIvan Butygin 
1125519847feSAlex Zinenko     if (getToken().is(Token::integer)) {
11260a81ace0SKazu Hirata       std::optional<uint64_t> value = getToken().getUInt64IntegerValue();
11270fd95808SAlex Zinenko       if (!value ||
11280fd95808SAlex Zinenko           *value > static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
1129519847feSAlex Zinenko         return emitWrongTokenError();
1130519847feSAlex Zinenko       consumeToken();
113154d81e49SIvan Butygin       auto result = static_cast<int64_t>(*value);
113254d81e49SIvan Butygin       if (negative)
113354d81e49SIvan Butygin         result = -result;
113454d81e49SIvan Butygin 
113554d81e49SIvan Butygin       return result;
1136519847feSAlex Zinenko     }
1137519847feSAlex Zinenko 
1138519847feSAlex Zinenko     return emitWrongTokenError();
1139519847feSAlex Zinenko   };
1140519847feSAlex Zinenko 
1141519847feSAlex Zinenko   // Parse strides.
1142519847feSAlex Zinenko   SmallVector<int64_t> strides;
1143519847feSAlex Zinenko   if (!getToken().is(Token::r_square)) {
1144519847feSAlex Zinenko     do {
11450a81ace0SKazu Hirata       std::optional<int64_t> stride = parseStrideOrOffset();
1146519847feSAlex Zinenko       if (!stride)
1147519847feSAlex Zinenko         return nullptr;
1148519847feSAlex Zinenko       strides.push_back(*stride);
1149519847feSAlex Zinenko     } while (consumeIf(Token::comma));
1150519847feSAlex Zinenko   }
1151519847feSAlex Zinenko 
1152519847feSAlex Zinenko   if (failed(parseToken(Token::r_square, "expected ']'")))
1153519847feSAlex Zinenko     return nullptr;
1154519847feSAlex Zinenko 
1155519847feSAlex Zinenko   // Fast path in absence of offset.
1156519847feSAlex Zinenko   if (consumeIf(Token::greater)) {
1157519847feSAlex Zinenko     if (failed(StridedLayoutAttr::verify(errorEmitter,
1158519847feSAlex Zinenko                                          /*offset=*/0, strides)))
1159519847feSAlex Zinenko       return nullptr;
1160519847feSAlex Zinenko     return StridedLayoutAttr::get(getContext(), /*offset=*/0, strides);
1161519847feSAlex Zinenko   }
1162519847feSAlex Zinenko 
1163519847feSAlex Zinenko   if (failed(parseToken(Token::comma, "expected ','")) ||
1164519847feSAlex Zinenko       failed(parseToken(Token::kw_offset, "expected 'offset' after comma")) ||
1165519847feSAlex Zinenko       failed(parseToken(Token::colon, "expected ':' after 'offset'")))
1166519847feSAlex Zinenko     return nullptr;
1167519847feSAlex Zinenko 
11680a81ace0SKazu Hirata   std::optional<int64_t> offset = parseStrideOrOffset();
1169519847feSAlex Zinenko   if (!offset || failed(parseToken(Token::greater, "expected '>'")))
1170519847feSAlex Zinenko     return nullptr;
1171519847feSAlex Zinenko 
1172519847feSAlex Zinenko   if (failed(StridedLayoutAttr::verify(errorEmitter, *offset, strides)))
1173519847feSAlex Zinenko     return nullptr;
1174519847feSAlex Zinenko   return StridedLayoutAttr::get(getContext(), *offset, strides);
1175519847feSAlex Zinenko   // return getChecked<StridedLayoutAttr>(loc,getContext(), *offset, strides);
1176519847feSAlex Zinenko }
1177728a8d5aSTobias Gysi 
1178728a8d5aSTobias Gysi /// Parse a distinct attribute.
1179728a8d5aSTobias Gysi ///
1180728a8d5aSTobias Gysi ///  distinct-attribute ::= `distinct`
1181728a8d5aSTobias Gysi ///                         `[` integer-literal `]<` attribute-value `>`
1182728a8d5aSTobias Gysi ///
1183728a8d5aSTobias Gysi Attribute Parser::parseDistinctAttr(Type type) {
118472307960SJacques Pienaar   SMLoc loc = getToken().getLoc();
1185728a8d5aSTobias Gysi   consumeToken(Token::kw_distinct);
1186728a8d5aSTobias Gysi   if (parseToken(Token::l_square, "expected '[' after 'distinct'"))
1187728a8d5aSTobias Gysi     return {};
1188728a8d5aSTobias Gysi 
1189728a8d5aSTobias Gysi   // Parse the distinct integer identifier.
1190728a8d5aSTobias Gysi   Token token = getToken();
1191728a8d5aSTobias Gysi   if (parseToken(Token::integer, "expected distinct ID"))
1192728a8d5aSTobias Gysi     return {};
1193728a8d5aSTobias Gysi   std::optional<uint64_t> value = token.getUInt64IntegerValue();
1194728a8d5aSTobias Gysi   if (!value) {
1195728a8d5aSTobias Gysi     emitError("expected an unsigned 64-bit integer");
1196728a8d5aSTobias Gysi     return {};
1197728a8d5aSTobias Gysi   }
1198728a8d5aSTobias Gysi 
1199728a8d5aSTobias Gysi   // Parse the referenced attribute.
1200728a8d5aSTobias Gysi   if (parseToken(Token::r_square, "expected ']' to close distinct ID") ||
1201728a8d5aSTobias Gysi       parseToken(Token::less, "expected '<' after distinct ID"))
1202728a8d5aSTobias Gysi     return {};
1203629460a9SMarkus Böck 
1204629460a9SMarkus Böck   Attribute referencedAttr;
1205629460a9SMarkus Böck   if (getToken().is(Token::greater)) {
1206629460a9SMarkus Böck     consumeToken();
1207629460a9SMarkus Böck     referencedAttr = builder.getUnitAttr();
1208629460a9SMarkus Böck   } else {
1209629460a9SMarkus Böck     referencedAttr = parseAttribute(type);
1210728a8d5aSTobias Gysi     if (!referencedAttr) {
1211728a8d5aSTobias Gysi       emitError("expected attribute");
1212728a8d5aSTobias Gysi       return {};
1213728a8d5aSTobias Gysi     }
1214728a8d5aSTobias Gysi 
1215629460a9SMarkus Böck     if (parseToken(Token::greater, "expected '>' to close distinct attribute"))
1216629460a9SMarkus Böck       return {};
1217629460a9SMarkus Böck   }
1218629460a9SMarkus Böck 
1219728a8d5aSTobias Gysi   // Add the distinct attribute to the parser state, if it has not been parsed
1220728a8d5aSTobias Gysi   // before. Otherwise, check if the parsed reference attribute matches the one
1221728a8d5aSTobias Gysi   // found in the parser state.
1222728a8d5aSTobias Gysi   DenseMap<uint64_t, DistinctAttr> &distinctAttrs =
1223728a8d5aSTobias Gysi       state.symbols.distinctAttributes;
1224728a8d5aSTobias Gysi   auto it = distinctAttrs.find(*value);
1225728a8d5aSTobias Gysi   if (it == distinctAttrs.end()) {
1226728a8d5aSTobias Gysi     DistinctAttr distinctAttr = DistinctAttr::create(referencedAttr);
1227728a8d5aSTobias Gysi     it = distinctAttrs.try_emplace(*value, distinctAttr).first;
1228728a8d5aSTobias Gysi   } else if (it->getSecond().getReferencedAttr() != referencedAttr) {
122972307960SJacques Pienaar     emitError(loc, "referenced attribute does not match previous definition: ")
1230728a8d5aSTobias Gysi         << it->getSecond().getReferencedAttr();
1231728a8d5aSTobias Gysi     return {};
1232728a8d5aSTobias Gysi   }
1233728a8d5aSTobias Gysi 
1234728a8d5aSTobias Gysi   return it->getSecond();
1235728a8d5aSTobias Gysi }
1236