xref: /llvm-project/mlir/lib/AsmParser/AttributeParser.cpp (revision 4548bff0e8139d4f375f1078dd50a74116eae0a2)
1 //===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the parser for the MLIR Attributes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Parser.h"
14 
15 #include "AsmParserImpl.h"
16 #include "mlir/AsmParser/AsmParserState.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/BuiltinAttributes.h"
19 #include "mlir/IR/BuiltinDialect.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/DialectImplementation.h"
22 #include "mlir/IR/DialectResourceBlobManager.h"
23 #include "mlir/IR/IntegerSet.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/Support/Endian.h"
26 #include <optional>
27 
28 using namespace mlir;
29 using namespace mlir::detail;
30 
31 /// Parse an arbitrary attribute.
32 ///
33 ///  attribute-value ::= `unit`
34 ///                    | bool-literal
35 ///                    | integer-literal (`:` (index-type | integer-type))?
36 ///                    | float-literal (`:` float-type)?
37 ///                    | string-literal (`:` type)?
38 ///                    | type
39 ///                    | `[` `:` (integer-type | float-type) tensor-literal `]`
40 ///                    | `[` (attribute-value (`,` attribute-value)*)? `]`
41 ///                    | `{` (attribute-entry (`,` attribute-entry)*)? `}`
42 ///                    | symbol-ref-id (`::` symbol-ref-id)*
43 ///                    | `dense` `<` tensor-literal `>` `:`
44 ///                      (tensor-type | vector-type)
45 ///                    | `sparse` `<` attribute-value `,` attribute-value `>`
46 ///                      `:` (tensor-type | vector-type)
47 ///                    | `strided` `<` `[` comma-separated-int-or-question `]`
48 ///                      (`,` `offset` `:` integer-literal)? `>`
49 ///                    | distinct-attribute
50 ///                    | extended-attribute
51 ///
52 Attribute Parser::parseAttribute(Type type) {
53   switch (getToken().getKind()) {
54   // Parse an AffineMap or IntegerSet attribute.
55   case Token::kw_affine_map: {
56     consumeToken(Token::kw_affine_map);
57 
58     AffineMap map;
59     if (parseToken(Token::less, "expected '<' in affine map") ||
60         parseAffineMapReference(map) ||
61         parseToken(Token::greater, "expected '>' in affine map"))
62       return Attribute();
63     return AffineMapAttr::get(map);
64   }
65   case Token::kw_affine_set: {
66     consumeToken(Token::kw_affine_set);
67 
68     IntegerSet set;
69     if (parseToken(Token::less, "expected '<' in integer set") ||
70         parseIntegerSetReference(set) ||
71         parseToken(Token::greater, "expected '>' in integer set"))
72       return Attribute();
73     return IntegerSetAttr::get(set);
74   }
75 
76   // Parse an array attribute.
77   case Token::l_square: {
78     consumeToken(Token::l_square);
79     SmallVector<Attribute, 4> elements;
80     auto parseElt = [&]() -> ParseResult {
81       elements.push_back(parseAttribute());
82       return elements.back() ? success() : failure();
83     };
84 
85     if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
86       return nullptr;
87     return builder.getArrayAttr(elements);
88   }
89 
90   // Parse a boolean attribute.
91   case Token::kw_false:
92     consumeToken(Token::kw_false);
93     return builder.getBoolAttr(false);
94   case Token::kw_true:
95     consumeToken(Token::kw_true);
96     return builder.getBoolAttr(true);
97 
98   // Parse a dense elements attribute.
99   case Token::kw_dense:
100     return parseDenseElementsAttr(type);
101 
102   // Parse a dense resource elements attribute.
103   case Token::kw_dense_resource:
104     return parseDenseResourceElementsAttr(type);
105 
106   // Parse a dense array attribute.
107   case Token::kw_array:
108     return parseDenseArrayAttr(type);
109 
110   // Parse a dictionary attribute.
111   case Token::l_brace: {
112     NamedAttrList elements;
113     if (parseAttributeDict(elements))
114       return nullptr;
115     return elements.getDictionary(getContext());
116   }
117 
118   // Parse an extended attribute, i.e. alias or dialect attribute.
119   case Token::hash_identifier:
120     return parseExtendedAttr(type);
121 
122   // Parse floating point and integer attributes.
123   case Token::floatliteral:
124     return parseFloatAttr(type, /*isNegative=*/false);
125   case Token::integer:
126     return parseDecOrHexAttr(type, /*isNegative=*/false);
127   case Token::minus: {
128     consumeToken(Token::minus);
129     if (getToken().is(Token::integer))
130       return parseDecOrHexAttr(type, /*isNegative=*/true);
131     if (getToken().is(Token::floatliteral))
132       return parseFloatAttr(type, /*isNegative=*/true);
133 
134     return (emitWrongTokenError(
135                 "expected constant integer or floating point value"),
136             nullptr);
137   }
138 
139   // Parse a location attribute.
140   case Token::kw_loc: {
141     consumeToken(Token::kw_loc);
142 
143     LocationAttr locAttr;
144     if (parseToken(Token::l_paren, "expected '(' in inline location") ||
145         parseLocationInstance(locAttr) ||
146         parseToken(Token::r_paren, "expected ')' in inline location"))
147       return Attribute();
148     return locAttr;
149   }
150 
151   // Parse a sparse elements attribute.
152   case Token::kw_sparse:
153     return parseSparseElementsAttr(type);
154 
155   // Parse a strided layout attribute.
156   case Token::kw_strided:
157     return parseStridedLayoutAttr();
158 
159   // Parse a distinct attribute.
160   case Token::kw_distinct:
161     return parseDistinctAttr(type);
162 
163   // Parse a string attribute.
164   case Token::string: {
165     auto val = getToken().getStringValue();
166     consumeToken(Token::string);
167     // Parse the optional trailing colon type if one wasn't explicitly provided.
168     if (!type && consumeIf(Token::colon) && !(type = parseType()))
169       return Attribute();
170 
171     return type ? StringAttr::get(val, type)
172                 : StringAttr::get(getContext(), val);
173   }
174 
175   // Parse a symbol reference attribute.
176   case Token::at_identifier: {
177     // When populating the parser state, this is a list of locations for all of
178     // the nested references.
179     SmallVector<SMRange> referenceLocations;
180     if (state.asmState)
181       referenceLocations.push_back(getToken().getLocRange());
182 
183     // Parse the top-level reference.
184     std::string nameStr = getToken().getSymbolReference();
185     consumeToken(Token::at_identifier);
186 
187     // Parse any nested references.
188     std::vector<FlatSymbolRefAttr> nestedRefs;
189     while (getToken().is(Token::colon)) {
190       // Check for the '::' prefix.
191       const char *curPointer = getToken().getLoc().getPointer();
192       consumeToken(Token::colon);
193       if (!consumeIf(Token::colon)) {
194         if (getToken().isNot(Token::eof, Token::error)) {
195           state.lex.resetPointer(curPointer);
196           consumeToken();
197         }
198         break;
199       }
200       // Parse the reference itself.
201       auto curLoc = getToken().getLoc();
202       if (getToken().isNot(Token::at_identifier)) {
203         emitError(curLoc, "expected nested symbol reference identifier");
204         return Attribute();
205       }
206 
207       // If we are populating the assembly state, add the location for this
208       // reference.
209       if (state.asmState)
210         referenceLocations.push_back(getToken().getLocRange());
211 
212       std::string nameStr = getToken().getSymbolReference();
213       consumeToken(Token::at_identifier);
214       nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
215     }
216     SymbolRefAttr symbolRefAttr =
217         SymbolRefAttr::get(getContext(), nameStr, nestedRefs);
218 
219     // If we are populating the assembly state, record this symbol reference.
220     if (state.asmState)
221       state.asmState->addUses(symbolRefAttr, referenceLocations);
222     return symbolRefAttr;
223   }
224 
225   // Parse a 'unit' attribute.
226   case Token::kw_unit:
227     consumeToken(Token::kw_unit);
228     return builder.getUnitAttr();
229 
230     // Handle completion of an attribute.
231   case Token::code_complete:
232     if (getToken().isCodeCompletionFor(Token::hash_identifier))
233       return parseExtendedAttr(type);
234     return codeCompleteAttribute();
235 
236   default:
237     // Parse a type attribute. We parse `Optional` here to allow for providing a
238     // better error message.
239     Type type;
240     OptionalParseResult result = parseOptionalType(type);
241     if (!result.has_value())
242       return emitWrongTokenError("expected attribute value"), Attribute();
243     return failed(*result) ? Attribute() : TypeAttr::get(type);
244   }
245 }
246 
247 /// Parse an optional attribute with the provided type.
248 OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
249                                                    Type type) {
250   switch (getToken().getKind()) {
251   case Token::at_identifier:
252   case Token::floatliteral:
253   case Token::integer:
254   case Token::hash_identifier:
255   case Token::kw_affine_map:
256   case Token::kw_affine_set:
257   case Token::kw_dense:
258   case Token::kw_dense_resource:
259   case Token::kw_false:
260   case Token::kw_loc:
261   case Token::kw_sparse:
262   case Token::kw_true:
263   case Token::kw_unit:
264   case Token::l_brace:
265   case Token::l_square:
266   case Token::minus:
267   case Token::string:
268     attribute = parseAttribute(type);
269     return success(attribute != nullptr);
270 
271   default:
272     // Parse an optional type attribute.
273     Type type;
274     OptionalParseResult result = parseOptionalType(type);
275     if (result.has_value() && succeeded(*result))
276       attribute = TypeAttr::get(type);
277     return result;
278   }
279 }
/// Parse an optional array attribute. Delegates to
/// parseOptionalAttributeWithToken, passing Token::l_square as the token kind.
OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
                                                   Type type) {
  return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
}
/// Parse an optional string attribute. Delegates to
/// parseOptionalAttributeWithToken, passing Token::string as the token kind.
OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute,
                                                   Type type) {
  return parseOptionalAttributeWithToken(Token::string, attribute, type);
}
/// Parse an optional symbol reference attribute. Delegates to
/// parseOptionalAttributeWithToken, passing Token::at_identifier as the token
/// kind.
OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result,
                                                   Type type) {
  return parseOptionalAttributeWithToken(Token::at_identifier, result, type);
}
292 
293 /// Attribute dictionary.
294 ///
295 ///   attribute-dict ::= `{` `}`
296 ///                    | `{` attribute-entry (`,` attribute-entry)* `}`
297 ///   attribute-entry ::= (bare-id | string-literal) `=` attribute-value
298 ///
299 ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
300   llvm::SmallDenseSet<StringAttr> seenKeys;
301   auto parseElt = [&]() -> ParseResult {
302     // The name of an attribute can either be a bare identifier, or a string.
303     std::optional<StringAttr> nameId;
304     if (getToken().is(Token::string))
305       nameId = builder.getStringAttr(getToken().getStringValue());
306     else if (getToken().isAny(Token::bare_identifier, Token::inttype) ||
307              getToken().isKeyword())
308       nameId = builder.getStringAttr(getTokenSpelling());
309     else
310       return emitWrongTokenError("expected attribute name");
311 
312     if (nameId->empty())
313       return emitError("expected valid attribute name");
314 
315     if (!seenKeys.insert(*nameId).second)
316       return emitError("duplicate key '")
317              << nameId->getValue() << "' in dictionary attribute";
318     consumeToken();
319 
320     // Lazy load a dialect in the context if there is a possible namespace.
321     auto splitName = nameId->strref().split('.');
322     if (!splitName.second.empty())
323       getContext()->getOrLoadDialect(splitName.first);
324 
325     // Try to parse the '=' for the attribute value.
326     if (!consumeIf(Token::equal)) {
327       // If there is no '=', we treat this as a unit attribute.
328       attributes.push_back({*nameId, builder.getUnitAttr()});
329       return success();
330     }
331 
332     auto attr = parseAttribute();
333     if (!attr)
334       return failure();
335     attributes.push_back({*nameId, attr});
336     return success();
337   };
338 
339   return parseCommaSeparatedList(Delimiter::Braces, parseElt,
340                                  " in attribute dictionary");
341 }
342 
343 /// Parse a float attribute.
344 Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
345   auto val = getToken().getFloatingPointValue();
346   if (!val)
347     return (emitError("floating point value too large for attribute"), nullptr);
348   consumeToken(Token::floatliteral);
349   if (!type) {
350     // Default to F64 when no type is specified.
351     if (!consumeIf(Token::colon))
352       type = builder.getF64Type();
353     else if (!(type = parseType()))
354       return nullptr;
355   }
356   if (!isa<FloatType>(type))
357     return (emitError("floating point value not valid for specified type"),
358             nullptr);
359   return FloatAttr::get(type, isNegative ? -*val : *val);
360 }
361 
362 /// Construct an APint from a parsed value, a known attribute type and
363 /// sign.
364 static std::optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
365                                                 StringRef spelling) {
366   // Parse the integer value into an APInt that is big enough to hold the value.
367   APInt result;
368   bool isHex = spelling.size() > 1 && spelling[1] == 'x';
369   if (spelling.getAsInteger(isHex ? 0 : 10, result))
370     return std::nullopt;
371 
372   // Extend or truncate the bitwidth to the right size.
373   unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
374                                   : type.getIntOrFloatBitWidth();
375 
376   if (width > result.getBitWidth()) {
377     result = result.zext(width);
378   } else if (width < result.getBitWidth()) {
379     // The parser can return an unnecessarily wide result with leading zeros.
380     // This isn't a problem, but truncating off bits is bad.
381     if (result.countl_zero() < result.getBitWidth() - width)
382       return std::nullopt;
383 
384     result = result.trunc(width);
385   }
386 
387   if (width == 0) {
388     // 0 bit integers cannot be negative and manipulation of their sign bit will
389     // assert, so short-cut validation here.
390     if (isNegative)
391       return std::nullopt;
392   } else if (isNegative) {
393     // The value is negative, we have an overflow if the sign bit is not set
394     // in the negated apInt.
395     result.negate();
396     if (!result.isSignBitSet())
397       return std::nullopt;
398   } else if ((type.isSignedInteger() || type.isIndex()) &&
399              result.isSignBitSet()) {
400     // The value is a positive signed integer or index,
401     // we have an overflow if the sign bit is set.
402     return std::nullopt;
403   }
404 
405   return result;
406 }
407 
408 /// Parse a decimal or a hexadecimal literal, which can be either an integer
409 /// or a float attribute.
410 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
411   Token tok = getToken();
412   StringRef spelling = tok.getSpelling();
413   SMLoc loc = tok.getLoc();
414 
415   consumeToken(Token::integer);
416   if (!type) {
417     // Default to i64 if not type is specified.
418     if (!consumeIf(Token::colon))
419       type = builder.getIntegerType(64);
420     else if (!(type = parseType()))
421       return nullptr;
422   }
423 
424   if (auto floatType = dyn_cast<FloatType>(type)) {
425     std::optional<APFloat> result;
426     if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
427                                             floatType.getFloatSemantics())))
428       return Attribute();
429     return FloatAttr::get(floatType, *result);
430   }
431 
432   if (!isa<IntegerType, IndexType>(type))
433     return emitError(loc, "integer literal not valid for specified type"),
434            nullptr;
435 
436   if (isNegative && type.isUnsignedInteger()) {
437     emitError(loc,
438               "negative integer literal not valid for unsigned integer type");
439     return nullptr;
440   }
441 
442   std::optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling);
443   if (!apInt)
444     return emitError(loc, "integer constant out of range for attribute"),
445            nullptr;
446   return builder.getIntegerAttr(type, *apInt);
447 }
448 
449 //===----------------------------------------------------------------------===//
450 // TensorLiteralParser
451 //===----------------------------------------------------------------------===//
452 
453 /// Parse elements values stored within a hex string. On success, the values are
454 /// stored into 'result'.
455 static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
456                                              std::string &result) {
457   if (std::optional<std::string> value = tok.getHexStringValue()) {
458     result = std::move(*value);
459     return success();
460   }
461   return parser.emitError(
462       tok.getLoc(), "expected string containing hex digits starting with `0x`");
463 }
464 
namespace {
/// This class implements a parser for TensorLiterals. A tensor literal is
/// either a single element (e.g, 5) or a multi-dimensional list of elements
/// (e.g., [[5, 5]]).
///
/// Usage is two-phase: `parse` collects the raw element tokens (or one hex
/// string token) and the inferred shape; `getAttr` then converts the collected
/// tokens into a DenseElementsAttr for a concrete shaped type.
class TensorLiteralParser {
public:
  /// Create a literal parser that reads tokens from, and reports errors
  /// through, the parser `p`.
  TensorLiteralParser(Parser &p) : p(p) {}

  /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
  /// may also parse a tensor literal that is stored as a hex string.
  ParseResult parse(bool allowHex);

  /// Build a dense attribute instance with the parsed elements and the given
  /// shaped type. Errors are reported at `loc`.
  DenseElementsAttr getAttr(SMLoc loc, ShapedType type);

  /// Return the shape inferred from the parsed literal.
  ArrayRef<int64_t> getShape() const { return shape; }

private:
  /// Get the parsed elements for an integer attribute.
  ParseResult getIntAttrElements(SMLoc loc, Type eltTy,
                                 std::vector<APInt> &intValues);

  /// Get the parsed elements for a float attribute.
  ParseResult getFloatAttrElements(SMLoc loc, FloatType eltTy,
                                   std::vector<APFloat> &floatValues);

  /// Build a Dense String attribute for the given type.
  DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy);

  /// Build a Dense attribute with hex data for the given type.
  DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type);

  /// Parse a single element, returning failure if it isn't a valid element
  /// literal. For example:
  /// parseElement(1) -> Success, 1
  /// parseElement([1]) -> Failure
  ParseResult parseElement();

  /// Parse a list of either lists or elements, returning the dimensions of the
  /// parsed sub-tensors in dims. For example:
  ///   parseList([1, 2, 3]) -> Success, [3]
  ///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
  ///   parseList([[1, 2], 3]) -> Failure
  ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
  ParseResult parseList(SmallVectorImpl<int64_t> &dims);

  /// Parse a literal that was printed as a hex string.
  ParseResult parseHexElements();

  /// The parser used to read tokens and emit diagnostics.
  Parser &p;

  /// The shape inferred from the parsed elements.
  SmallVector<int64_t, 4> shape;

  /// Storage used when parsing elements, this is a pair of <is_negated, token>.
  std::vector<std::pair<bool, Token>> storage;

  /// Storage used when parsing elements that were stored as hex values.
  std::optional<Token> hexStorage;
};
} // namespace
527 
528 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
529 /// may also parse a tensor literal that is store as a hex string.
530 ParseResult TensorLiteralParser::parse(bool allowHex) {
531   // If hex is allowed, check for a string literal.
532   if (allowHex && p.getToken().is(Token::string)) {
533     hexStorage = p.getToken();
534     p.consumeToken(Token::string);
535     return success();
536   }
537   // Otherwise, parse a list or an individual element.
538   if (p.getToken().is(Token::l_square))
539     return parseList(shape);
540   return parseElement();
541 }
542 
543 /// Build a dense attribute instance with the parsed elements and the given
544 /// shaped type.
545 DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
546   Type eltType = type.getElementType();
547 
548   // Check to see if we parse the literal from a hex string.
549   if (hexStorage &&
550       (eltType.isIntOrIndexOrFloat() || isa<ComplexType>(eltType)))
551     return getHexAttr(loc, type);
552 
553   // Check that the parsed storage size has the same number of elements to the
554   // type, or is a known splat.
555   if (!shape.empty() && getShape() != type.getShape()) {
556     p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
557                      << "]) does not match type ([" << type.getShape() << "])";
558     return nullptr;
559   }
560 
561   // Handle the case where no elements were parsed.
562   if (!hexStorage && storage.empty() && type.getNumElements()) {
563     p.emitError(loc) << "parsed zero elements, but type (" << type
564                      << ") expected at least 1";
565     return nullptr;
566   }
567 
568   // Handle complex types in the specific element type cases below.
569   bool isComplex = false;
570   if (ComplexType complexTy = dyn_cast<ComplexType>(eltType)) {
571     eltType = complexTy.getElementType();
572     isComplex = true;
573   }
574 
575   // Handle integer and index types.
576   if (eltType.isIntOrIndex()) {
577     std::vector<APInt> intValues;
578     if (failed(getIntAttrElements(loc, eltType, intValues)))
579       return nullptr;
580     if (isComplex) {
581       // If this is a complex, treat the parsed values as complex values.
582       auto complexData = llvm::ArrayRef(
583           reinterpret_cast<std::complex<APInt> *>(intValues.data()),
584           intValues.size() / 2);
585       return DenseElementsAttr::get(type, complexData);
586     }
587     return DenseElementsAttr::get(type, intValues);
588   }
589   // Handle floating point types.
590   if (FloatType floatTy = dyn_cast<FloatType>(eltType)) {
591     std::vector<APFloat> floatValues;
592     if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
593       return nullptr;
594     if (isComplex) {
595       // If this is a complex, treat the parsed values as complex values.
596       auto complexData = llvm::ArrayRef(
597           reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
598           floatValues.size() / 2);
599       return DenseElementsAttr::get(type, complexData);
600     }
601     return DenseElementsAttr::get(type, floatValues);
602   }
603 
604   // Other types are assumed to be string representations.
605   return getStringAttr(loc, type, type.getElementType());
606 }
607 
608 /// Build a Dense Integer attribute for the given type.
609 ParseResult
610 TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
611                                         std::vector<APInt> &intValues) {
612   intValues.reserve(storage.size());
613   bool isUintType = eltTy.isUnsignedInteger();
614   for (const auto &signAndToken : storage) {
615     bool isNegative = signAndToken.first;
616     const Token &token = signAndToken.second;
617     auto tokenLoc = token.getLoc();
618 
619     if (isNegative && isUintType) {
620       return p.emitError(tokenLoc)
621              << "expected unsigned integer elements, but parsed negative value";
622     }
623 
624     // Check to see if floating point values were parsed.
625     if (token.is(Token::floatliteral)) {
626       return p.emitError(tokenLoc)
627              << "expected integer elements, but parsed floating-point";
628     }
629 
630     assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
631            "unexpected token type");
632     if (token.isAny(Token::kw_true, Token::kw_false)) {
633       if (!eltTy.isInteger(1)) {
634         return p.emitError(tokenLoc)
635                << "expected i1 type for 'true' or 'false' values";
636       }
637       APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
638       intValues.push_back(apInt);
639       continue;
640     }
641 
642     // Create APInt values for each element with the correct bitwidth.
643     std::optional<APInt> apInt =
644         buildAttributeAPInt(eltTy, isNegative, token.getSpelling());
645     if (!apInt)
646       return p.emitError(tokenLoc, "integer constant out of range for type");
647     intValues.push_back(*apInt);
648   }
649   return success();
650 }
651 
652 /// Build a Dense Float attribute for the given type.
653 ParseResult
654 TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
655                                           std::vector<APFloat> &floatValues) {
656   floatValues.reserve(storage.size());
657   for (const auto &signAndToken : storage) {
658     bool isNegative = signAndToken.first;
659     const Token &token = signAndToken.second;
660     std::optional<APFloat> result;
661     if (failed(p.parseFloatFromLiteral(result, token, isNegative,
662                                        eltTy.getFloatSemantics())))
663       return failure();
664     floatValues.push_back(*result);
665   }
666   return success();
667 }
668 
669 /// Build a Dense String attribute for the given type.
670 DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type,
671                                                      Type eltTy) {
672   if (hexStorage.has_value()) {
673     auto stringValue = hexStorage->getStringValue();
674     return DenseStringElementsAttr::get(type, {stringValue});
675   }
676 
677   std::vector<std::string> stringValues;
678   std::vector<StringRef> stringRefValues;
679   stringValues.reserve(storage.size());
680   stringRefValues.reserve(storage.size());
681 
682   for (auto val : storage) {
683     stringValues.push_back(val.second.getStringValue());
684     stringRefValues.emplace_back(stringValues.back());
685   }
686 
687   return DenseStringElementsAttr::get(type, stringRefValues);
688 }
689 
690 /// Build a Dense attribute with hex data for the given type.
691 DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) {
692   Type elementType = type.getElementType();
693   if (!elementType.isIntOrIndexOrFloat() && !isa<ComplexType>(elementType)) {
694     p.emitError(loc)
695         << "expected floating-point, integer, or complex element type, got "
696         << elementType;
697     return nullptr;
698   }
699 
700   std::string data;
701   if (parseElementAttrHexValues(p, *hexStorage, data))
702     return nullptr;
703 
704   ArrayRef<char> rawData(data.data(), data.size());
705   bool detectedSplat = false;
706   if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) {
707     p.emitError(loc) << "elements hex data size is invalid for provided type: "
708                      << type;
709     return nullptr;
710   }
711 
712   if (llvm::endianness::native == llvm::endianness::big) {
713     // Convert endianess in big-endian(BE) machines. `rawData` is
714     // little-endian(LE) because HEX in raw data of dense element attribute
715     // is always LE format. It is converted into BE here to be used in BE
716     // machines.
717     SmallVector<char, 64> outDataVec(rawData.size());
718     MutableArrayRef<char> convRawData(outDataVec);
719     DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
720         rawData, convRawData, type);
721     return DenseElementsAttr::getFromRawBuffer(type, convRawData);
722   }
723 
724   return DenseElementsAttr::getFromRawBuffer(type, rawData);
725 }
726 
727 ParseResult TensorLiteralParser::parseElement() {
728   switch (p.getToken().getKind()) {
729   // Parse a boolean element.
730   case Token::kw_true:
731   case Token::kw_false:
732   case Token::floatliteral:
733   case Token::integer:
734     storage.emplace_back(/*isNegative=*/false, p.getToken());
735     p.consumeToken();
736     break;
737 
738   // Parse a signed integer or a negative floating-point element.
739   case Token::minus:
740     p.consumeToken(Token::minus);
741     if (!p.getToken().isAny(Token::floatliteral, Token::integer))
742       return p.emitError("expected integer or floating point literal");
743     storage.emplace_back(/*isNegative=*/true, p.getToken());
744     p.consumeToken();
745     break;
746 
747   case Token::string:
748     storage.emplace_back(/*isNegative=*/false, p.getToken());
749     p.consumeToken();
750     break;
751 
752   // Parse a complex element of the form '(' element ',' element ')'.
753   case Token::l_paren:
754     p.consumeToken(Token::l_paren);
755     if (parseElement() ||
756         p.parseToken(Token::comma, "expected ',' between complex elements") ||
757         parseElement() ||
758         p.parseToken(Token::r_paren, "expected ')' after complex elements"))
759       return failure();
760     break;
761 
762   default:
763     return p.emitError("expected element literal of primitive type");
764   }
765 
766   return success();
767 }
768 
769 /// Parse a list of either lists or elements, returning the dimensions of the
770 /// parsed sub-tensors in dims. For example:
771 ///   parseList([1, 2, 3]) -> Success, [3]
772 ///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
773 ///   parseList([[1, 2], 3]) -> Failure
774 ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
775 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
776   auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
777                        const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
778     if (prevDims == newDims)
779       return success();
780     return p.emitError("tensor literal is invalid; ranks are not consistent "
781                        "between elements");
782   };
783 
784   bool first = true;
785   SmallVector<int64_t, 4> newDims;
786   unsigned size = 0;
787   auto parseOneElement = [&]() -> ParseResult {
788     SmallVector<int64_t, 4> thisDims;
789     if (p.getToken().getKind() == Token::l_square) {
790       if (parseList(thisDims))
791         return failure();
792     } else if (parseElement()) {
793       return failure();
794     }
795     ++size;
796     if (!first)
797       return checkDims(newDims, thisDims);
798     newDims = thisDims;
799     first = false;
800     return success();
801   };
802   if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseOneElement))
803     return failure();
804 
805   // Return the sublists' dimensions with 'size' prepended.
806   dims.clear();
807   dims.push_back(size);
808   dims.append(newDims.begin(), newDims.end());
809   return success();
810 }
811 
812 //===----------------------------------------------------------------------===//
813 // DenseArrayAttr Parser
814 //===----------------------------------------------------------------------===//
815 
816 namespace {
817 /// A generic dense array element parser. It parsers integer and floating point
818 /// elements.
819 class DenseArrayElementParser {
820 public:
821   explicit DenseArrayElementParser(Type type) : type(type) {}
822 
823   /// Parse an integer element.
824   ParseResult parseIntegerElement(Parser &p);
825 
826   /// Parse a floating point element.
827   ParseResult parseFloatElement(Parser &p);
828 
829   /// Convert the current contents to a dense array.
830   DenseArrayAttr getAttr() { return DenseArrayAttr::get(type, size, rawData); }
831 
832 private:
833   /// Append the raw data of an APInt to the result.
834   void append(const APInt &data);
835 
836   /// The array element type.
837   Type type;
838   /// The resultant byte array representing the contents of the array.
839   std::vector<char> rawData;
840   /// The number of elements in the array.
841   int64_t size = 0;
842 };
843 } // namespace
844 
845 void DenseArrayElementParser::append(const APInt &data) {
846   if (data.getBitWidth()) {
847     assert(data.getBitWidth() % 8 == 0);
848     unsigned byteSize = data.getBitWidth() / 8;
849     size_t offset = rawData.size();
850     rawData.insert(rawData.end(), byteSize, 0);
851     llvm::StoreIntToMemory(
852         data, reinterpret_cast<uint8_t *>(rawData.data() + offset), byteSize);
853   }
854   ++size;
855 }
856 
857 ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) {
858   bool isNegative = p.consumeIf(Token::minus);
859 
860   // Parse an integer literal as an APInt.
861   std::optional<APInt> value;
862   StringRef spelling = p.getToken().getSpelling();
863   if (p.getToken().isAny(Token::kw_true, Token::kw_false)) {
864     if (!type.isInteger(1))
865       return p.emitError("expected i1 type for 'true' or 'false' values");
866     value = APInt(/*numBits=*/8, p.getToken().is(Token::kw_true),
867                   !type.isUnsignedInteger());
868     p.consumeToken();
869   } else if (p.consumeIf(Token::integer)) {
870     value = buildAttributeAPInt(type, isNegative, spelling);
871     if (!value)
872       return p.emitError("integer constant out of range");
873   } else {
874     return p.emitError("expected integer literal");
875   }
876   append(*value);
877   return success();
878 }
879 
880 ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
881   bool isNegative = p.consumeIf(Token::minus);
882   Token token = p.getToken();
883   std::optional<APFloat> fromIntLit;
884   if (failed(
885           p.parseFloatFromLiteral(fromIntLit, token, isNegative,
886                                   cast<FloatType>(type).getFloatSemantics())))
887     return failure();
888   p.consumeToken();
889   append(fromIntLit->bitcastToAPInt());
890   return success();
891 }
892 
893 /// Parse a dense array attribute.
894 Attribute Parser::parseDenseArrayAttr(Type attrType) {
895   consumeToken(Token::kw_array);
896   if (parseToken(Token::less, "expected '<' after 'array'"))
897     return {};
898 
899   SMLoc typeLoc = getToken().getLoc();
900   Type eltType = parseType();
901   if (!eltType) {
902     emitError(typeLoc, "expected an integer or floating point type");
903     return {};
904   }
905 
906   // Only bool or integer and floating point elements divisible by bytes are
907   // supported.
908   if (!eltType.isIntOrIndexOrFloat()) {
909     emitError(typeLoc, "expected integer or float type, got: ") << eltType;
910     return {};
911   }
912   if (!eltType.isInteger(1) && eltType.getIntOrFloatBitWidth() % 8 != 0) {
913     emitError(typeLoc, "element type bitwidth must be a multiple of 8");
914     return {};
915   }
916 
917   // Check for empty list.
918   if (consumeIf(Token::greater))
919     return DenseArrayAttr::get(eltType, 0, {});
920 
921   if (parseToken(Token::colon, "expected ':' after dense array type"))
922     return {};
923 
924   DenseArrayElementParser eltParser(eltType);
925   if (eltType.isIntOrIndex()) {
926     if (parseCommaSeparatedList(
927             [&] { return eltParser.parseIntegerElement(*this); }))
928       return {};
929   } else {
930     if (parseCommaSeparatedList(
931             [&] { return eltParser.parseFloatElement(*this); }))
932       return {};
933   }
934   if (parseToken(Token::greater, "expected '>' to close an array attribute"))
935     return {};
936   return eltParser.getAttr();
937 }
938 
939 /// Parse a dense elements attribute.
940 Attribute Parser::parseDenseElementsAttr(Type attrType) {
941   auto attribLoc = getToken().getLoc();
942   consumeToken(Token::kw_dense);
943   if (parseToken(Token::less, "expected '<' after 'dense'"))
944     return nullptr;
945 
946   // Parse the literal data if necessary.
947   TensorLiteralParser literalParser(*this);
948   if (!consumeIf(Token::greater)) {
949     if (literalParser.parse(/*allowHex=*/true) ||
950         parseToken(Token::greater, "expected '>'"))
951       return nullptr;
952   }
953 
954   // If the type is specified `parseElementsLiteralType` will not parse a type.
955   // Use the attribute location as the location for error reporting in that
956   // case.
957   auto loc = attrType ? attribLoc : getToken().getLoc();
958   auto type = parseElementsLiteralType(attrType);
959   if (!type)
960     return nullptr;
961   return literalParser.getAttr(loc, type);
962 }
963 
964 Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
965   auto loc = getToken().getLoc();
966   consumeToken(Token::kw_dense_resource);
967   if (parseToken(Token::less, "expected '<' after 'dense_resource'"))
968     return nullptr;
969 
970   // Parse the resource handle.
971   FailureOr<AsmDialectResourceHandle> rawHandle =
972       parseResourceHandle(getContext()->getLoadedDialect<BuiltinDialect>());
973   if (failed(rawHandle) || parseToken(Token::greater, "expected '>'"))
974     return nullptr;
975 
976   auto *handle = dyn_cast<DenseResourceElementsHandle>(&*rawHandle);
977   if (!handle)
978     return emitError(loc, "invalid `dense_resource` handle type"), nullptr;
979 
980   // Parse the type of the attribute if the user didn't provide one.
981   SMLoc typeLoc = loc;
982   if (!attrType) {
983     typeLoc = getToken().getLoc();
984     if (parseToken(Token::colon, "expected ':'") || !(attrType = parseType()))
985       return nullptr;
986   }
987 
988   ShapedType shapedType = dyn_cast<ShapedType>(attrType);
989   if (!shapedType) {
990     emitError(typeLoc, "`dense_resource` expected a shaped type");
991     return nullptr;
992   }
993 
994   return DenseResourceElementsAttr::get(shapedType, *handle);
995 }
996 
997 /// Shaped type for elements attribute.
998 ///
999 ///   elements-literal-type ::= vector-type | ranked-tensor-type
1000 ///
1001 /// This method also checks the type has static shape.
1002 ShapedType Parser::parseElementsLiteralType(Type type) {
1003   // If the user didn't provide a type, parse the colon type for the literal.
1004   if (!type) {
1005     if (parseToken(Token::colon, "expected ':'"))
1006       return nullptr;
1007     if (!(type = parseType()))
1008       return nullptr;
1009   }
1010 
1011   auto sType = dyn_cast<ShapedType>(type);
1012   if (!sType) {
1013     emitError("elements literal must be a shaped type");
1014     return nullptr;
1015   }
1016 
1017   if (!sType.hasStaticShape())
1018     return (emitError("elements literal type must have static shape"), nullptr);
1019 
1020   return sType;
1021 }
1022 
1023 /// Parse a sparse elements attribute.
1024 Attribute Parser::parseSparseElementsAttr(Type attrType) {
1025   SMLoc loc = getToken().getLoc();
1026   consumeToken(Token::kw_sparse);
1027   if (parseToken(Token::less, "Expected '<' after 'sparse'"))
1028     return nullptr;
1029 
1030   // Check for the case where all elements are sparse. The indices are
1031   // represented by a 2-dimensional shape where the second dimension is the rank
1032   // of the type.
1033   Type indiceEltType = builder.getIntegerType(64);
1034   if (consumeIf(Token::greater)) {
1035     ShapedType type = parseElementsLiteralType(attrType);
1036     if (!type)
1037       return nullptr;
1038 
1039     // Construct the sparse elements attr using zero element indice/value
1040     // attributes.
1041     ShapedType indicesType =
1042         RankedTensorType::get({0, type.getRank()}, indiceEltType);
1043     ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
1044     return getChecked<SparseElementsAttr>(
1045         loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
1046         DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
1047   }
1048 
1049   /// Parse the indices. We don't allow hex values here as we may need to use
1050   /// the inferred shape.
1051   auto indicesLoc = getToken().getLoc();
1052   TensorLiteralParser indiceParser(*this);
1053   if (indiceParser.parse(/*allowHex=*/false))
1054     return nullptr;
1055 
1056   if (parseToken(Token::comma, "expected ','"))
1057     return nullptr;
1058 
1059   /// Parse the values.
1060   auto valuesLoc = getToken().getLoc();
1061   TensorLiteralParser valuesParser(*this);
1062   if (valuesParser.parse(/*allowHex=*/true))
1063     return nullptr;
1064 
1065   if (parseToken(Token::greater, "expected '>'"))
1066     return nullptr;
1067 
1068   auto type = parseElementsLiteralType(attrType);
1069   if (!type)
1070     return nullptr;
1071 
1072   // If the indices are a splat, i.e. the literal parser parsed an element and
1073   // not a list, we set the shape explicitly. The indices are represented by a
1074   // 2-dimensional shape where the second dimension is the rank of the type.
1075   // Given that the parsed indices is a splat, we know that we only have one
1076   // indice and thus one for the first dimension.
1077   ShapedType indicesType;
1078   if (indiceParser.getShape().empty()) {
1079     indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
1080   } else {
1081     // Otherwise, set the shape to the one parsed by the literal parser.
1082     indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
1083   }
1084   auto indices = indiceParser.getAttr(indicesLoc, indicesType);
1085 
1086   // If the values are a splat, set the shape explicitly based on the number of
1087   // indices. The number of indices is encoded in the first dimension of the
1088   // indice shape type.
1089   auto valuesEltType = type.getElementType();
1090   ShapedType valuesType =
1091       valuesParser.getShape().empty()
1092           ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
1093           : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
1094   auto values = valuesParser.getAttr(valuesLoc, valuesType);
1095 
1096   // Build the sparse elements attribute by the indices and values.
1097   return getChecked<SparseElementsAttr>(loc, type, indices, values);
1098 }
1099 
1100 Attribute Parser::parseStridedLayoutAttr() {
1101   // Callback for error emissing at the keyword token location.
1102   llvm::SMLoc loc = getToken().getLoc();
1103   auto errorEmitter = [&] { return emitError(loc); };
1104 
1105   consumeToken(Token::kw_strided);
1106   if (failed(parseToken(Token::less, "expected '<' after 'strided'")) ||
1107       failed(parseToken(Token::l_square, "expected '['")))
1108     return nullptr;
1109 
1110   // Parses either an integer token or a question mark token. Reports an error
1111   // and returns std::nullopt if the current token is neither. The integer token
1112   // must fit into int64_t limits.
1113   auto parseStrideOrOffset = [&]() -> std::optional<int64_t> {
1114     if (consumeIf(Token::question))
1115       return ShapedType::kDynamic;
1116 
1117     SMLoc loc = getToken().getLoc();
1118     auto emitWrongTokenError = [&] {
1119       emitError(loc, "expected a 64-bit signed integer or '?'");
1120       return std::nullopt;
1121     };
1122 
1123     bool negative = consumeIf(Token::minus);
1124 
1125     if (getToken().is(Token::integer)) {
1126       std::optional<uint64_t> value = getToken().getUInt64IntegerValue();
1127       if (!value ||
1128           *value > static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
1129         return emitWrongTokenError();
1130       consumeToken();
1131       auto result = static_cast<int64_t>(*value);
1132       if (negative)
1133         result = -result;
1134 
1135       return result;
1136     }
1137 
1138     return emitWrongTokenError();
1139   };
1140 
1141   // Parse strides.
1142   SmallVector<int64_t> strides;
1143   if (!getToken().is(Token::r_square)) {
1144     do {
1145       std::optional<int64_t> stride = parseStrideOrOffset();
1146       if (!stride)
1147         return nullptr;
1148       strides.push_back(*stride);
1149     } while (consumeIf(Token::comma));
1150   }
1151 
1152   if (failed(parseToken(Token::r_square, "expected ']'")))
1153     return nullptr;
1154 
1155   // Fast path in absence of offset.
1156   if (consumeIf(Token::greater)) {
1157     if (failed(StridedLayoutAttr::verify(errorEmitter,
1158                                          /*offset=*/0, strides)))
1159       return nullptr;
1160     return StridedLayoutAttr::get(getContext(), /*offset=*/0, strides);
1161   }
1162 
1163   if (failed(parseToken(Token::comma, "expected ','")) ||
1164       failed(parseToken(Token::kw_offset, "expected 'offset' after comma")) ||
1165       failed(parseToken(Token::colon, "expected ':' after 'offset'")))
1166     return nullptr;
1167 
1168   std::optional<int64_t> offset = parseStrideOrOffset();
1169   if (!offset || failed(parseToken(Token::greater, "expected '>'")))
1170     return nullptr;
1171 
1172   if (failed(StridedLayoutAttr::verify(errorEmitter, *offset, strides)))
1173     return nullptr;
1174   return StridedLayoutAttr::get(getContext(), *offset, strides);
1175   // return getChecked<StridedLayoutAttr>(loc,getContext(), *offset, strides);
1176 }
1177 
1178 /// Parse a distinct attribute.
1179 ///
1180 ///  distinct-attribute ::= `distinct`
1181 ///                         `[` integer-literal `]<` attribute-value `>`
1182 ///
1183 Attribute Parser::parseDistinctAttr(Type type) {
1184   SMLoc loc = getToken().getLoc();
1185   consumeToken(Token::kw_distinct);
1186   if (parseToken(Token::l_square, "expected '[' after 'distinct'"))
1187     return {};
1188 
1189   // Parse the distinct integer identifier.
1190   Token token = getToken();
1191   if (parseToken(Token::integer, "expected distinct ID"))
1192     return {};
1193   std::optional<uint64_t> value = token.getUInt64IntegerValue();
1194   if (!value) {
1195     emitError("expected an unsigned 64-bit integer");
1196     return {};
1197   }
1198 
1199   // Parse the referenced attribute.
1200   if (parseToken(Token::r_square, "expected ']' to close distinct ID") ||
1201       parseToken(Token::less, "expected '<' after distinct ID"))
1202     return {};
1203 
1204   Attribute referencedAttr;
1205   if (getToken().is(Token::greater)) {
1206     consumeToken();
1207     referencedAttr = builder.getUnitAttr();
1208   } else {
1209     referencedAttr = parseAttribute(type);
1210     if (!referencedAttr) {
1211       emitError("expected attribute");
1212       return {};
1213     }
1214 
1215     if (parseToken(Token::greater, "expected '>' to close distinct attribute"))
1216       return {};
1217   }
1218 
1219   // Add the distinct attribute to the parser state, if it has not been parsed
1220   // before. Otherwise, check if the parsed reference attribute matches the one
1221   // found in the parser state.
1222   DenseMap<uint64_t, DistinctAttr> &distinctAttrs =
1223       state.symbols.distinctAttributes;
1224   auto it = distinctAttrs.find(*value);
1225   if (it == distinctAttrs.end()) {
1226     DistinctAttr distinctAttr = DistinctAttr::create(referencedAttr);
1227     it = distinctAttrs.try_emplace(*value, distinctAttr).first;
1228   } else if (it->getSecond().getReferencedAttr() != referencedAttr) {
1229     emitError(loc, "referenced attribute does not match previous definition: ")
1230         << it->getSecond().getReferencedAttr();
1231     return {};
1232   }
1233 
1234   return it->getSecond();
1235 }
1236