xref: /llvm-project/mlir/lib/AsmParser/TypeParser.cpp (revision f4943464d769e2eacd5c54dfaaf0468788abeb84)
1 //===- TypeParser.cpp - MLIR Type Parser Implementation -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the parser for the MLIR Types.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Parser.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/BuiltinAttributeInterfaces.h"
16 #include "mlir/IR/BuiltinTypeInterfaces.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/OpDefinition.h"
19 #include "mlir/IR/TensorEncoding.h"
20 #include "mlir/IR/Types.h"
21 #include "mlir/Support/LLVM.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include <cassert>
24 #include <cstdint>
25 #include <limits>
26 #include <optional>
27 
28 using namespace mlir;
29 using namespace mlir::detail;
30 
31 /// Optionally parse a type.
32 OptionalParseResult Parser::parseOptionalType(Type &type) {
33   // There are many different starting tokens for a type, check them here.
34   switch (getToken().getKind()) {
35   case Token::l_paren:
36   case Token::kw_memref:
37   case Token::kw_tensor:
38   case Token::kw_complex:
39   case Token::kw_tuple:
40   case Token::kw_vector:
41   case Token::inttype:
42   case Token::kw_f4E2M1FN:
43   case Token::kw_f6E2M3FN:
44   case Token::kw_f6E3M2FN:
45   case Token::kw_f8E5M2:
46   case Token::kw_f8E4M3:
47   case Token::kw_f8E4M3FN:
48   case Token::kw_f8E5M2FNUZ:
49   case Token::kw_f8E4M3FNUZ:
50   case Token::kw_f8E4M3B11FNUZ:
51   case Token::kw_f8E3M4:
52   case Token::kw_f8E8M0FNU:
53   case Token::kw_bf16:
54   case Token::kw_f16:
55   case Token::kw_tf32:
56   case Token::kw_f32:
57   case Token::kw_f64:
58   case Token::kw_f80:
59   case Token::kw_f128:
60   case Token::kw_index:
61   case Token::kw_none:
62   case Token::exclamation_identifier:
63     return failure(!(type = parseType()));
64 
65   default:
66     return std::nullopt;
67   }
68 }
69 
70 /// Parse an arbitrary type.
71 ///
72 ///   type ::= function-type
73 ///          | non-function-type
74 ///
75 Type Parser::parseType() {
76   if (getToken().is(Token::l_paren))
77     return parseFunctionType();
78   return parseNonFunctionType();
79 }
80 
81 /// Parse a function result type.
82 ///
83 ///   function-result-type ::= type-list-parens
84 ///                          | non-function-type
85 ///
86 ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
87   if (getToken().is(Token::l_paren))
88     return parseTypeListParens(elements);
89 
90   Type t = parseNonFunctionType();
91   if (!t)
92     return failure();
93   elements.push_back(t);
94   return success();
95 }
96 
97 /// Parse a list of types without an enclosing parenthesis.  The list must have
98 /// at least one member.
99 ///
100 ///   type-list-no-parens ::=  type (`,` type)*
101 ///
102 ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
103   auto parseElt = [&]() -> ParseResult {
104     auto elt = parseType();
105     elements.push_back(elt);
106     return elt ? success() : failure();
107   };
108 
109   return parseCommaSeparatedList(parseElt);
110 }
111 
112 /// Parse a parenthesized list of types.
113 ///
114 ///   type-list-parens ::= `(` `)`
115 ///                      | `(` type-list-no-parens `)`
116 ///
117 ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
118   if (parseToken(Token::l_paren, "expected '('"))
119     return failure();
120 
121   // Handle empty lists.
122   if (getToken().is(Token::r_paren))
123     return consumeToken(), success();
124 
125   if (parseTypeListNoParens(elements) ||
126       parseToken(Token::r_paren, "expected ')'"))
127     return failure();
128   return success();
129 }
130 
131 /// Parse a complex type.
132 ///
133 ///   complex-type ::= `complex` `<` type `>`
134 ///
135 Type Parser::parseComplexType() {
136   consumeToken(Token::kw_complex);
137 
138   // Parse the '<'.
139   if (parseToken(Token::less, "expected '<' in complex type"))
140     return nullptr;
141 
142   SMLoc elementTypeLoc = getToken().getLoc();
143   auto elementType = parseType();
144   if (!elementType ||
145       parseToken(Token::greater, "expected '>' in complex type"))
146     return nullptr;
147   if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
148     return emitError(elementTypeLoc, "invalid element type for complex"),
149            nullptr;
150 
151   return ComplexType::get(elementType);
152 }
153 
154 /// Parse a function type.
155 ///
156 ///   function-type ::= type-list-parens `->` function-result-type
157 ///
158 Type Parser::parseFunctionType() {
159   assert(getToken().is(Token::l_paren));
160 
161   SmallVector<Type, 4> arguments, results;
162   if (parseTypeListParens(arguments) ||
163       parseToken(Token::arrow, "expected '->' in function type") ||
164       parseFunctionResultTypes(results))
165     return nullptr;
166 
167   return builder.getFunctionType(arguments, results);
168 }
169 
170 /// Parse a memref type.
171 ///
172 ///   memref-type ::= ranked-memref-type | unranked-memref-type
173 ///
174 ///   ranked-memref-type ::= `memref` `<` dimension-list-ranked type
175 ///                          (`,` layout-specification)? (`,` memory-space)? `>`
176 ///
177 ///   unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
178 ///
179 ///   stride-list ::= `[` (dimension (`,` dimension)*)? `]`
180 ///   strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
181 ///   layout-specification ::= semi-affine-map | strided-layout | attribute
182 ///   memory-space ::= integer-literal | attribute
183 ///
184 Type Parser::parseMemRefType() {
185   SMLoc loc = getToken().getLoc();
186   consumeToken(Token::kw_memref);
187 
188   if (parseToken(Token::less, "expected '<' in memref type"))
189     return nullptr;
190 
191   bool isUnranked;
192   SmallVector<int64_t, 4> dimensions;
193 
194   if (consumeIf(Token::star)) {
195     // This is an unranked memref type.
196     isUnranked = true;
197     if (parseXInDimensionList())
198       return nullptr;
199 
200   } else {
201     isUnranked = false;
202     if (parseDimensionListRanked(dimensions))
203       return nullptr;
204   }
205 
206   // Parse the element type.
207   auto typeLoc = getToken().getLoc();
208   auto elementType = parseType();
209   if (!elementType)
210     return nullptr;
211 
212   // Check that memref is formed from allowed types.
213   if (!BaseMemRefType::isValidElementType(elementType))
214     return emitError(typeLoc, "invalid memref element type"), nullptr;
215 
216   MemRefLayoutAttrInterface layout;
217   Attribute memorySpace;
218 
219   auto parseElt = [&]() -> ParseResult {
220     // Either it is MemRefLayoutAttrInterface or memory space attribute.
221     Attribute attr = parseAttribute();
222     if (!attr)
223       return failure();
224 
225     if (isa<MemRefLayoutAttrInterface>(attr)) {
226       layout = cast<MemRefLayoutAttrInterface>(attr);
227     } else if (memorySpace) {
228       return emitError("multiple memory spaces specified in memref type");
229     } else {
230       memorySpace = attr;
231       return success();
232     }
233 
234     if (isUnranked)
235       return emitError("cannot have affine map for unranked memref type");
236     if (memorySpace)
237       return emitError("expected memory space to be last in memref type");
238 
239     return success();
240   };
241 
242   // Parse a list of mappings and address space if present.
243   if (!consumeIf(Token::greater)) {
244     // Parse comma separated list of affine maps, followed by memory space.
245     if (parseToken(Token::comma, "expected ',' or '>' in memref type") ||
246         parseCommaSeparatedListUntil(Token::greater, parseElt,
247                                      /*allowEmptyList=*/false)) {
248       return nullptr;
249     }
250   }
251 
252   if (isUnranked)
253     return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
254 
255   return getChecked<MemRefType>(loc, dimensions, elementType, layout,
256                                 memorySpace);
257 }
258 
259 /// Parse any type except the function type.
260 ///
261 ///   non-function-type ::= integer-type
262 ///                       | index-type
263 ///                       | float-type
264 ///                       | extended-type
265 ///                       | vector-type
266 ///                       | tensor-type
267 ///                       | memref-type
268 ///                       | complex-type
269 ///                       | tuple-type
270 ///                       | none-type
271 ///
272 ///   index-type ::= `index`
273 ///   float-type ::= `f16` | `bf16` | `f32` | `f64` | `f80` | `f128`
274 ///   none-type ::= `none`
275 ///
276 Type Parser::parseNonFunctionType() {
277   switch (getToken().getKind()) {
278   default:
279     return (emitWrongTokenError("expected non-function type"), nullptr);
280   case Token::kw_memref:
281     return parseMemRefType();
282   case Token::kw_tensor:
283     return parseTensorType();
284   case Token::kw_complex:
285     return parseComplexType();
286   case Token::kw_tuple:
287     return parseTupleType();
288   case Token::kw_vector:
289     return parseVectorType();
290   // integer-type
291   case Token::inttype: {
292     auto width = getToken().getIntTypeBitwidth();
293     if (!width.has_value())
294       return (emitError("invalid integer width"), nullptr);
295     if (*width > IntegerType::kMaxWidth) {
296       emitError(getToken().getLoc(), "integer bitwidth is limited to ")
297           << IntegerType::kMaxWidth << " bits";
298       return nullptr;
299     }
300 
301     IntegerType::SignednessSemantics signSemantics = IntegerType::Signless;
302     if (std::optional<bool> signedness = getToken().getIntTypeSignedness())
303       signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned;
304 
305     consumeToken(Token::inttype);
306     return IntegerType::get(getContext(), *width, signSemantics);
307   }
308 
309   // float-type
310   case Token::kw_f4E2M1FN:
311     consumeToken(Token::kw_f4E2M1FN);
312     return builder.getType<Float4E2M1FNType>();
313   case Token::kw_f6E2M3FN:
314     consumeToken(Token::kw_f6E2M3FN);
315     return builder.getType<Float6E2M3FNType>();
316   case Token::kw_f6E3M2FN:
317     consumeToken(Token::kw_f6E3M2FN);
318     return builder.getType<Float6E3M2FNType>();
319   case Token::kw_f8E5M2:
320     consumeToken(Token::kw_f8E5M2);
321     return builder.getType<Float8E5M2Type>();
322   case Token::kw_f8E4M3:
323     consumeToken(Token::kw_f8E4M3);
324     return builder.getType<Float8E4M3Type>();
325   case Token::kw_f8E4M3FN:
326     consumeToken(Token::kw_f8E4M3FN);
327     return builder.getType<Float8E4M3FNType>();
328   case Token::kw_f8E5M2FNUZ:
329     consumeToken(Token::kw_f8E5M2FNUZ);
330     return builder.getType<Float8E5M2FNUZType>();
331   case Token::kw_f8E4M3FNUZ:
332     consumeToken(Token::kw_f8E4M3FNUZ);
333     return builder.getType<Float8E4M3FNUZType>();
334   case Token::kw_f8E4M3B11FNUZ:
335     consumeToken(Token::kw_f8E4M3B11FNUZ);
336     return builder.getType<Float8E4M3B11FNUZType>();
337   case Token::kw_f8E3M4:
338     consumeToken(Token::kw_f8E3M4);
339     return builder.getType<Float8E3M4Type>();
340   case Token::kw_f8E8M0FNU:
341     consumeToken(Token::kw_f8E8M0FNU);
342     return builder.getType<Float8E8M0FNUType>();
343   case Token::kw_bf16:
344     consumeToken(Token::kw_bf16);
345     return builder.getType<BFloat16Type>();
346   case Token::kw_f16:
347     consumeToken(Token::kw_f16);
348     return builder.getType<Float16Type>();
349   case Token::kw_tf32:
350     consumeToken(Token::kw_tf32);
351     return builder.getType<FloatTF32Type>();
352   case Token::kw_f32:
353     consumeToken(Token::kw_f32);
354     return builder.getType<Float32Type>();
355   case Token::kw_f64:
356     consumeToken(Token::kw_f64);
357     return builder.getType<Float64Type>();
358   case Token::kw_f80:
359     consumeToken(Token::kw_f80);
360     return builder.getType<Float80Type>();
361   case Token::kw_f128:
362     consumeToken(Token::kw_f128);
363     return builder.getType<Float128Type>();
364 
365   // index-type
366   case Token::kw_index:
367     consumeToken(Token::kw_index);
368     return builder.getIndexType();
369 
370   // none-type
371   case Token::kw_none:
372     consumeToken(Token::kw_none);
373     return builder.getNoneType();
374 
375   // extended type
376   case Token::exclamation_identifier:
377     return parseExtendedType();
378 
379   // Handle completion of a dialect type.
380   case Token::code_complete:
381     if (getToken().isCodeCompletionFor(Token::exclamation_identifier))
382       return parseExtendedType();
383     return codeCompleteType();
384   }
385 }
386 
387 /// Parse a tensor type.
388 ///
389 ///   tensor-type ::= `tensor` `<` dimension-list type `>`
390 ///   dimension-list ::= dimension-list-ranked | `*x`
391 ///
392 Type Parser::parseTensorType() {
393   consumeToken(Token::kw_tensor);
394 
395   if (parseToken(Token::less, "expected '<' in tensor type"))
396     return nullptr;
397 
398   bool isUnranked;
399   SmallVector<int64_t, 4> dimensions;
400 
401   if (consumeIf(Token::star)) {
402     // This is an unranked tensor type.
403     isUnranked = true;
404 
405     if (parseXInDimensionList())
406       return nullptr;
407 
408   } else {
409     isUnranked = false;
410     if (parseDimensionListRanked(dimensions))
411       return nullptr;
412   }
413 
414   // Parse the element type.
415   auto elementTypeLoc = getToken().getLoc();
416   auto elementType = parseType();
417 
418   // Parse an optional encoding attribute.
419   Attribute encoding;
420   if (consumeIf(Token::comma)) {
421     auto parseResult = parseOptionalAttribute(encoding);
422     if (parseResult.has_value()) {
423       if (failed(parseResult.value()))
424         return nullptr;
425       if (auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) {
426         if (failed(v.verifyEncoding(dimensions, elementType,
427                                     [&] { return emitError(); })))
428           return nullptr;
429       }
430     }
431   }
432 
433   if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
434     return nullptr;
435   if (!TensorType::isValidElementType(elementType))
436     return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;
437 
438   if (isUnranked) {
439     if (encoding)
440       return emitError("cannot apply encoding to unranked tensor"), nullptr;
441     return UnrankedTensorType::get(elementType);
442   }
443   return RankedTensorType::get(dimensions, elementType, encoding);
444 }
445 
446 /// Parse a tuple type.
447 ///
448 ///   tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
449 ///
450 Type Parser::parseTupleType() {
451   consumeToken(Token::kw_tuple);
452 
453   // Parse the '<'.
454   if (parseToken(Token::less, "expected '<' in tuple type"))
455     return nullptr;
456 
457   // Check for an empty tuple by directly parsing '>'.
458   if (consumeIf(Token::greater))
459     return TupleType::get(getContext());
460 
461   // Parse the element types and the '>'.
462   SmallVector<Type, 4> types;
463   if (parseTypeListNoParens(types) ||
464       parseToken(Token::greater, "expected '>' in tuple type"))
465     return nullptr;
466 
467   return TupleType::get(getContext(), types);
468 }
469 
470 /// Parse a vector type.
471 ///
472 /// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
473 /// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
474 /// static-dim-list ::= decimal-literal (`x` decimal-literal)*
475 ///
476 VectorType Parser::parseVectorType() {
477   SMLoc loc = getToken().getLoc();
478   consumeToken(Token::kw_vector);
479 
480   if (parseToken(Token::less, "expected '<' in vector type"))
481     return nullptr;
482 
483   // Parse the dimensions.
484   SmallVector<int64_t, 4> dimensions;
485   SmallVector<bool, 4> scalableDims;
486   if (parseVectorDimensionList(dimensions, scalableDims))
487     return nullptr;
488 
489   // Parse the element type.
490   auto elementType = parseType();
491   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
492     return nullptr;
493 
494   return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
495 }
496 
497 /// Parse a dimension list in a vector type. This populates the dimension list.
498 /// For i-th dimension, `scalableDims[i]` contains either:
499 ///   * `false` for a non-scalable dimension (e.g. `4`),
500 ///   * `true` for a scalable dimension (e.g. `[4]`).
501 ///
502 /// vector-dim-list := (static-dim-list `x`)?
503 /// static-dim-list ::= static-dim (`x` static-dim)*
504 /// static-dim ::= (decimal-literal | `[` decimal-literal `]`)
505 ///
506 ParseResult
507 Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
508                                  SmallVectorImpl<bool> &scalableDims) {
509   // If there is a set of fixed-length dimensions, consume it
510   while (getToken().is(Token::integer) || getToken().is(Token::l_square)) {
511     int64_t value;
512     bool scalable = consumeIf(Token::l_square);
513     if (parseIntegerInDimensionList(value))
514       return failure();
515     dimensions.push_back(value);
516     if (scalable) {
517       if (!consumeIf(Token::r_square))
518         return emitWrongTokenError("missing ']' closing scalable dimension");
519     }
520     scalableDims.push_back(scalable);
521     // Make sure we have an 'x' or something like 'xbf32'.
522     if (parseXInDimensionList())
523       return failure();
524   }
525 
526   return success();
527 }
528 
529 /// Parse a dimension list of a tensor or memref type.  This populates the
530 /// dimension list, using ShapedType::kDynamic for the `?` dimensions if
531 /// `allowDynamic` is set and errors out on `?` otherwise. Parsing the trailing
532 /// `x` is configurable.
533 ///
534 ///   dimension-list ::= eps | dimension (`x` dimension)*
535 ///   dimension-list-with-trailing-x ::= (dimension `x`)*
536 ///   dimension ::= `?` | decimal-literal
537 ///
538 /// When `allowDynamic` is not set, this is used to parse:
539 ///
540 ///   static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
541 ///   static-dimension-list-with-trailing-x ::= (dimension `x`)*
542 ParseResult
543 Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
544                                  bool allowDynamic, bool withTrailingX) {
545   auto parseDim = [&]() -> LogicalResult {
546     auto loc = getToken().getLoc();
547     if (consumeIf(Token::question)) {
548       if (!allowDynamic)
549         return emitError(loc, "expected static shape");
550       dimensions.push_back(ShapedType::kDynamic);
551     } else {
552       int64_t value;
553       if (failed(parseIntegerInDimensionList(value)))
554         return failure();
555       dimensions.push_back(value);
556     }
557     return success();
558   };
559 
560   if (withTrailingX) {
561     while (getToken().isAny(Token::integer, Token::question)) {
562       if (failed(parseDim()) || failed(parseXInDimensionList()))
563         return failure();
564     }
565     return success();
566   }
567 
568   if (getToken().isAny(Token::integer, Token::question)) {
569     if (failed(parseDim()))
570       return failure();
571     while (getToken().is(Token::bare_identifier) &&
572            getTokenSpelling()[0] == 'x') {
573       if (failed(parseXInDimensionList()) || failed(parseDim()))
574         return failure();
575     }
576   }
577   return success();
578 }
579 
580 ParseResult Parser::parseIntegerInDimensionList(int64_t &value) {
581   // Hexadecimal integer literals (starting with `0x`) are not allowed in
582   // aggregate type declarations.  Therefore, `0xf32` should be processed as
583   // a sequence of separate elements `0`, `x`, `f32`.
584   if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
585     // We can get here only if the token is an integer literal.  Hexadecimal
586     // integer literals can only start with `0x` (`1x` wouldn't lex as a
587     // literal, just `1` would, at which point we don't get into this
588     // branch).
589     assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
590     value = 0;
591     state.lex.resetPointer(getTokenSpelling().data() + 1);
592     consumeToken();
593   } else {
594     // Make sure this integer value is in bound and valid.
595     std::optional<uint64_t> dimension = getToken().getUInt64IntegerValue();
596     if (!dimension ||
597         *dimension > (uint64_t)std::numeric_limits<int64_t>::max())
598       return emitError("invalid dimension");
599     value = (int64_t)*dimension;
600     consumeToken(Token::integer);
601   }
602   return success();
603 }
604 
605 /// Parse an 'x' token in a dimension list, handling the case where the x is
606 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
607 /// token.
608 ParseResult Parser::parseXInDimensionList() {
609   if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x')
610     return emitWrongTokenError("expected 'x' in dimension list");
611 
612   // If we had a prefix of 'x', lex the next token immediately after the 'x'.
613   if (getTokenSpelling().size() != 1)
614     state.lex.resetPointer(getTokenSpelling().data() + 1);
615 
616   // Consume the 'x'.
617   consumeToken(Token::bare_identifier);
618 
619   return success();
620 }
621