1 //===- TypeParser.cpp - MLIR Type Parser Implementation -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the parser for the MLIR Types. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Parser.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/BuiltinAttributeInterfaces.h" 16 #include "mlir/IR/BuiltinTypeInterfaces.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "mlir/IR/OpDefinition.h" 19 #include "mlir/IR/TensorEncoding.h" 20 #include "mlir/IR/Types.h" 21 #include "mlir/Support/LLVM.h" 22 #include "llvm/ADT/STLExtras.h" 23 #include <cassert> 24 #include <cstdint> 25 #include <limits> 26 #include <optional> 27 28 using namespace mlir; 29 using namespace mlir::detail; 30 31 /// Optionally parse a type. 32 OptionalParseResult Parser::parseOptionalType(Type &type) { 33 // There are many different starting tokens for a type, check them here. 34 switch (getToken().getKind()) { 35 case Token::l_paren: 36 case Token::kw_memref: 37 case Token::kw_tensor: 38 case Token::kw_complex: 39 case Token::kw_tuple: 40 case Token::kw_vector: 41 case Token::inttype: 42 case Token::kw_f4E2M1FN: 43 case Token::kw_f6E2M3FN: 44 case Token::kw_f6E3M2FN: 45 case Token::kw_f8E5M2: 46 case Token::kw_f8E4M3: 47 case Token::kw_f8E4M3FN: 48 case Token::kw_f8E5M2FNUZ: 49 case Token::kw_f8E4M3FNUZ: 50 case Token::kw_f8E4M3B11FNUZ: 51 case Token::kw_f8E3M4: 52 case Token::kw_f8E8M0FNU: 53 case Token::kw_bf16: 54 case Token::kw_f16: 55 case Token::kw_tf32: 56 case Token::kw_f32: 57 case Token::kw_f64: 58 case Token::kw_f80: 59 case Token::kw_f128: 60 case Token::kw_index: 61 case Token::kw_none: 62 case Token::exclamation_identifier: 63 return failure(!(type = parseType())); 64 65 default: 66 return std::nullopt; 67 } 68 } 69 70 /// Parse an arbitrary type. 71 /// 72 /// type ::= function-type 73 /// | non-function-type 74 /// 75 Type Parser::parseType() { 76 if (getToken().is(Token::l_paren)) 77 return parseFunctionType(); 78 return parseNonFunctionType(); 79 } 80 81 /// Parse a function result type. 82 /// 83 /// function-result-type ::= type-list-parens 84 /// | non-function-type 85 /// 86 ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) { 87 if (getToken().is(Token::l_paren)) 88 return parseTypeListParens(elements); 89 90 Type t = parseNonFunctionType(); 91 if (!t) 92 return failure(); 93 elements.push_back(t); 94 return success(); 95 } 96 97 /// Parse a list of types without an enclosing parenthesis. The list must have 98 /// at least one member. 99 /// 100 /// type-list-no-parens ::= type (`,` type)* 101 /// 102 ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) { 103 auto parseElt = [&]() -> ParseResult { 104 auto elt = parseType(); 105 elements.push_back(elt); 106 return elt ? success() : failure(); 107 }; 108 109 return parseCommaSeparatedList(parseElt); 110 } 111 112 /// Parse a parenthesized list of types. 113 /// 114 /// type-list-parens ::= `(` `)` 115 /// | `(` type-list-no-parens `)` 116 /// 117 ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) { 118 if (parseToken(Token::l_paren, "expected '('")) 119 return failure(); 120 121 // Handle empty lists. 122 if (getToken().is(Token::r_paren)) 123 return consumeToken(), success(); 124 125 if (parseTypeListNoParens(elements) || 126 parseToken(Token::r_paren, "expected ')'")) 127 return failure(); 128 return success(); 129 } 130 131 /// Parse a complex type. 132 /// 133 /// complex-type ::= `complex` `<` type `>` 134 /// 135 Type Parser::parseComplexType() { 136 consumeToken(Token::kw_complex); 137 138 // Parse the '<'. 139 if (parseToken(Token::less, "expected '<' in complex type")) 140 return nullptr; 141 142 SMLoc elementTypeLoc = getToken().getLoc(); 143 auto elementType = parseType(); 144 if (!elementType || 145 parseToken(Token::greater, "expected '>' in complex type")) 146 return nullptr; 147 if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType)) 148 return emitError(elementTypeLoc, "invalid element type for complex"), 149 nullptr; 150 151 return ComplexType::get(elementType); 152 } 153 154 /// Parse a function type. 155 /// 156 /// function-type ::= type-list-parens `->` function-result-type 157 /// 158 Type Parser::parseFunctionType() { 159 assert(getToken().is(Token::l_paren)); 160 161 SmallVector<Type, 4> arguments, results; 162 if (parseTypeListParens(arguments) || 163 parseToken(Token::arrow, "expected '->' in function type") || 164 parseFunctionResultTypes(results)) 165 return nullptr; 166 167 return builder.getFunctionType(arguments, results); 168 } 169 170 /// Parse a memref type. 171 /// 172 /// memref-type ::= ranked-memref-type | unranked-memref-type 173 /// 174 /// ranked-memref-type ::= `memref` `<` dimension-list-ranked type 175 /// (`,` layout-specification)? (`,` memory-space)? `>` 176 /// 177 /// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>` 178 /// 179 /// stride-list ::= `[` (dimension (`,` dimension)*)? `]` 180 /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list 181 /// layout-specification ::= semi-affine-map | strided-layout | attribute 182 /// memory-space ::= integer-literal | attribute 183 /// 184 Type Parser::parseMemRefType() { 185 SMLoc loc = getToken().getLoc(); 186 consumeToken(Token::kw_memref); 187 188 if (parseToken(Token::less, "expected '<' in memref type")) 189 return nullptr; 190 191 bool isUnranked; 192 SmallVector<int64_t, 4> dimensions; 193 194 if (consumeIf(Token::star)) { 195 // This is an unranked memref type. 196 isUnranked = true; 197 if (parseXInDimensionList()) 198 return nullptr; 199 200 } else { 201 isUnranked = false; 202 if (parseDimensionListRanked(dimensions)) 203 return nullptr; 204 } 205 206 // Parse the element type. 207 auto typeLoc = getToken().getLoc(); 208 auto elementType = parseType(); 209 if (!elementType) 210 return nullptr; 211 212 // Check that memref is formed from allowed types. 213 if (!BaseMemRefType::isValidElementType(elementType)) 214 return emitError(typeLoc, "invalid memref element type"), nullptr; 215 216 MemRefLayoutAttrInterface layout; 217 Attribute memorySpace; 218 219 auto parseElt = [&]() -> ParseResult { 220 // Either it is MemRefLayoutAttrInterface or memory space attribute. 221 Attribute attr = parseAttribute(); 222 if (!attr) 223 return failure(); 224 225 if (isa<MemRefLayoutAttrInterface>(attr)) { 226 layout = cast<MemRefLayoutAttrInterface>(attr); 227 } else if (memorySpace) { 228 return emitError("multiple memory spaces specified in memref type"); 229 } else { 230 memorySpace = attr; 231 return success(); 232 } 233 234 if (isUnranked) 235 return emitError("cannot have affine map for unranked memref type"); 236 if (memorySpace) 237 return emitError("expected memory space to be last in memref type"); 238 239 return success(); 240 }; 241 242 // Parse a list of mappings and address space if present. 243 if (!consumeIf(Token::greater)) { 244 // Parse comma separated list of affine maps, followed by memory space. 245 if (parseToken(Token::comma, "expected ',' or '>' in memref type") || 246 parseCommaSeparatedListUntil(Token::greater, parseElt, 247 /*allowEmptyList=*/false)) { 248 return nullptr; 249 } 250 } 251 252 if (isUnranked) 253 return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace); 254 255 return getChecked<MemRefType>(loc, dimensions, elementType, layout, 256 memorySpace); 257 } 258 259 /// Parse any type except the function type. 260 /// 261 /// non-function-type ::= integer-type 262 /// | index-type 263 /// | float-type 264 /// | extended-type 265 /// | vector-type 266 /// | tensor-type 267 /// | memref-type 268 /// | complex-type 269 /// | tuple-type 270 /// | none-type 271 /// 272 /// index-type ::= `index` 273 /// float-type ::= `f16` | `bf16` | `f32` | `f64` | `f80` | `f128` 274 /// none-type ::= `none` 275 /// 276 Type Parser::parseNonFunctionType() { 277 switch (getToken().getKind()) { 278 default: 279 return (emitWrongTokenError("expected non-function type"), nullptr); 280 case Token::kw_memref: 281 return parseMemRefType(); 282 case Token::kw_tensor: 283 return parseTensorType(); 284 case Token::kw_complex: 285 return parseComplexType(); 286 case Token::kw_tuple: 287 return parseTupleType(); 288 case Token::kw_vector: 289 return parseVectorType(); 290 // integer-type 291 case Token::inttype: { 292 auto width = getToken().getIntTypeBitwidth(); 293 if (!width.has_value()) 294 return (emitError("invalid integer width"), nullptr); 295 if (*width > IntegerType::kMaxWidth) { 296 emitError(getToken().getLoc(), "integer bitwidth is limited to ") 297 << IntegerType::kMaxWidth << " bits"; 298 return nullptr; 299 } 300 301 IntegerType::SignednessSemantics signSemantics = IntegerType::Signless; 302 if (std::optional<bool> signedness = getToken().getIntTypeSignedness()) 303 signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned; 304 305 consumeToken(Token::inttype); 306 return IntegerType::get(getContext(), *width, signSemantics); 307 } 308 309 // float-type 310 case Token::kw_f4E2M1FN: 311 consumeToken(Token::kw_f4E2M1FN); 312 return builder.getType<Float4E2M1FNType>(); 313 case Token::kw_f6E2M3FN: 314 consumeToken(Token::kw_f6E2M3FN); 315 return builder.getType<Float6E2M3FNType>(); 316 case Token::kw_f6E3M2FN: 317 consumeToken(Token::kw_f6E3M2FN); 318 return builder.getType<Float6E3M2FNType>(); 319 case Token::kw_f8E5M2: 320 consumeToken(Token::kw_f8E5M2); 321 return builder.getType<Float8E5M2Type>(); 322 case Token::kw_f8E4M3: 323 consumeToken(Token::kw_f8E4M3); 324 return builder.getType<Float8E4M3Type>(); 325 case Token::kw_f8E4M3FN: 326 consumeToken(Token::kw_f8E4M3FN); 327 return builder.getType<Float8E4M3FNType>(); 328 case Token::kw_f8E5M2FNUZ: 329 consumeToken(Token::kw_f8E5M2FNUZ); 330 return builder.getType<Float8E5M2FNUZType>(); 331 case Token::kw_f8E4M3FNUZ: 332 consumeToken(Token::kw_f8E4M3FNUZ); 333 return builder.getType<Float8E4M3FNUZType>(); 334 case Token::kw_f8E4M3B11FNUZ: 335 consumeToken(Token::kw_f8E4M3B11FNUZ); 336 return builder.getType<Float8E4M3B11FNUZType>(); 337 case Token::kw_f8E3M4: 338 consumeToken(Token::kw_f8E3M4); 339 return builder.getType<Float8E3M4Type>(); 340 case Token::kw_f8E8M0FNU: 341 consumeToken(Token::kw_f8E8M0FNU); 342 return builder.getType<Float8E8M0FNUType>(); 343 case Token::kw_bf16: 344 consumeToken(Token::kw_bf16); 345 return builder.getType<BFloat16Type>(); 346 case Token::kw_f16: 347 consumeToken(Token::kw_f16); 348 return builder.getType<Float16Type>(); 349 case Token::kw_tf32: 350 consumeToken(Token::kw_tf32); 351 return builder.getType<FloatTF32Type>(); 352 case Token::kw_f32: 353 consumeToken(Token::kw_f32); 354 return builder.getType<Float32Type>(); 355 case Token::kw_f64: 356 consumeToken(Token::kw_f64); 357 return builder.getType<Float64Type>(); 358 case Token::kw_f80: 359 consumeToken(Token::kw_f80); 360 return builder.getType<Float80Type>(); 361 case Token::kw_f128: 362 consumeToken(Token::kw_f128); 363 return builder.getType<Float128Type>(); 364 365 // index-type 366 case Token::kw_index: 367 consumeToken(Token::kw_index); 368 return builder.getIndexType(); 369 370 // none-type 371 case Token::kw_none: 372 consumeToken(Token::kw_none); 373 return builder.getNoneType(); 374 375 // extended type 376 case Token::exclamation_identifier: 377 return parseExtendedType(); 378 379 // Handle completion of a dialect type. 380 case Token::code_complete: 381 if (getToken().isCodeCompletionFor(Token::exclamation_identifier)) 382 return parseExtendedType(); 383 return codeCompleteType(); 384 } 385 } 386 387 /// Parse a tensor type. 388 /// 389 /// tensor-type ::= `tensor` `<` dimension-list type `>` 390 /// dimension-list ::= dimension-list-ranked | `*x` 391 /// 392 Type Parser::parseTensorType() { 393 consumeToken(Token::kw_tensor); 394 395 if (parseToken(Token::less, "expected '<' in tensor type")) 396 return nullptr; 397 398 bool isUnranked; 399 SmallVector<int64_t, 4> dimensions; 400 401 if (consumeIf(Token::star)) { 402 // This is an unranked tensor type. 403 isUnranked = true; 404 405 if (parseXInDimensionList()) 406 return nullptr; 407 408 } else { 409 isUnranked = false; 410 if (parseDimensionListRanked(dimensions)) 411 return nullptr; 412 } 413 414 // Parse the element type. 415 auto elementTypeLoc = getToken().getLoc(); 416 auto elementType = parseType(); 417 418 // Parse an optional encoding attribute. 419 Attribute encoding; 420 if (consumeIf(Token::comma)) { 421 auto parseResult = parseOptionalAttribute(encoding); 422 if (parseResult.has_value()) { 423 if (failed(parseResult.value())) 424 return nullptr; 425 if (auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) { 426 if (failed(v.verifyEncoding(dimensions, elementType, 427 [&] { return emitError(); }))) 428 return nullptr; 429 } 430 } 431 } 432 433 if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) 434 return nullptr; 435 if (!TensorType::isValidElementType(elementType)) 436 return emitError(elementTypeLoc, "invalid tensor element type"), nullptr; 437 438 if (isUnranked) { 439 if (encoding) 440 return emitError("cannot apply encoding to unranked tensor"), nullptr; 441 return UnrankedTensorType::get(elementType); 442 } 443 return RankedTensorType::get(dimensions, elementType, encoding); 444 } 445 446 /// Parse a tuple type. 447 /// 448 /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` 449 /// 450 Type Parser::parseTupleType() { 451 consumeToken(Token::kw_tuple); 452 453 // Parse the '<'. 454 if (parseToken(Token::less, "expected '<' in tuple type")) 455 return nullptr; 456 457 // Check for an empty tuple by directly parsing '>'. 458 if (consumeIf(Token::greater)) 459 return TupleType::get(getContext()); 460 461 // Parse the element types and the '>'. 462 SmallVector<Type, 4> types; 463 if (parseTypeListNoParens(types) || 464 parseToken(Token::greater, "expected '>' in tuple type")) 465 return nullptr; 466 467 return TupleType::get(getContext(), types); 468 } 469 470 /// Parse a vector type. 471 /// 472 /// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>` 473 /// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)? 474 /// static-dim-list ::= decimal-literal (`x` decimal-literal)* 475 /// 476 VectorType Parser::parseVectorType() { 477 SMLoc loc = getToken().getLoc(); 478 consumeToken(Token::kw_vector); 479 480 if (parseToken(Token::less, "expected '<' in vector type")) 481 return nullptr; 482 483 // Parse the dimensions. 484 SmallVector<int64_t, 4> dimensions; 485 SmallVector<bool, 4> scalableDims; 486 if (parseVectorDimensionList(dimensions, scalableDims)) 487 return nullptr; 488 489 // Parse the element type. 490 auto elementType = parseType(); 491 if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) 492 return nullptr; 493 494 return getChecked<VectorType>(loc, dimensions, elementType, scalableDims); 495 } 496 497 /// Parse a dimension list in a vector type. This populates the dimension list. 498 /// For i-th dimension, `scalableDims[i]` contains either: 499 /// * `false` for a non-scalable dimension (e.g. `4`), 500 /// * `true` for a scalable dimension (e.g. `[4]`). 501 /// 502 /// vector-dim-list := (static-dim-list `x`)? 503 /// static-dim-list ::= static-dim (`x` static-dim)* 504 /// static-dim ::= (decimal-literal | `[` decimal-literal `]`) 505 /// 506 ParseResult 507 Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions, 508 SmallVectorImpl<bool> &scalableDims) { 509 // If there is a set of fixed-length dimensions, consume it 510 while (getToken().is(Token::integer) || getToken().is(Token::l_square)) { 511 int64_t value; 512 bool scalable = consumeIf(Token::l_square); 513 if (parseIntegerInDimensionList(value)) 514 return failure(); 515 dimensions.push_back(value); 516 if (scalable) { 517 if (!consumeIf(Token::r_square)) 518 return emitWrongTokenError("missing ']' closing scalable dimension"); 519 } 520 scalableDims.push_back(scalable); 521 // Make sure we have an 'x' or something like 'xbf32'. 522 if (parseXInDimensionList()) 523 return failure(); 524 } 525 526 return success(); 527 } 528 529 /// Parse a dimension list of a tensor or memref type. This populates the 530 /// dimension list, using ShapedType::kDynamic for the `?` dimensions if 531 /// `allowDynamic` is set and errors out on `?` otherwise. Parsing the trailing 532 /// `x` is configurable. 533 /// 534 /// dimension-list ::= eps | dimension (`x` dimension)* 535 /// dimension-list-with-trailing-x ::= (dimension `x`)* 536 /// dimension ::= `?` | decimal-literal 537 /// 538 /// When `allowDynamic` is not set, this is used to parse: 539 /// 540 /// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)* 541 /// static-dimension-list-with-trailing-x ::= (dimension `x`)* 542 ParseResult 543 Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, 544 bool allowDynamic, bool withTrailingX) { 545 auto parseDim = [&]() -> LogicalResult { 546 auto loc = getToken().getLoc(); 547 if (consumeIf(Token::question)) { 548 if (!allowDynamic) 549 return emitError(loc, "expected static shape"); 550 dimensions.push_back(ShapedType::kDynamic); 551 } else { 552 int64_t value; 553 if (failed(parseIntegerInDimensionList(value))) 554 return failure(); 555 dimensions.push_back(value); 556 } 557 return success(); 558 }; 559 560 if (withTrailingX) { 561 while (getToken().isAny(Token::integer, Token::question)) { 562 if (failed(parseDim()) || failed(parseXInDimensionList())) 563 return failure(); 564 } 565 return success(); 566 } 567 568 if (getToken().isAny(Token::integer, Token::question)) { 569 if (failed(parseDim())) 570 return failure(); 571 while (getToken().is(Token::bare_identifier) && 572 getTokenSpelling()[0] == 'x') { 573 if (failed(parseXInDimensionList()) || failed(parseDim())) 574 return failure(); 575 } 576 } 577 return success(); 578 } 579 580 ParseResult Parser::parseIntegerInDimensionList(int64_t &value) { 581 // Hexadecimal integer literals (starting with `0x`) are not allowed in 582 // aggregate type declarations. Therefore, `0xf32` should be processed as 583 // a sequence of separate elements `0`, `x`, `f32`. 584 if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { 585 // We can get here only if the token is an integer literal. Hexadecimal 586 // integer literals can only start with `0x` (`1x` wouldn't lex as a 587 // literal, just `1` would, at which point we don't get into this 588 // branch). 589 assert(getTokenSpelling()[0] == '0' && "invalid integer literal"); 590 value = 0; 591 state.lex.resetPointer(getTokenSpelling().data() + 1); 592 consumeToken(); 593 } else { 594 // Make sure this integer value is in bound and valid. 595 std::optional<uint64_t> dimension = getToken().getUInt64IntegerValue(); 596 if (!dimension || 597 *dimension > (uint64_t)std::numeric_limits<int64_t>::max()) 598 return emitError("invalid dimension"); 599 value = (int64_t)*dimension; 600 consumeToken(Token::integer); 601 } 602 return success(); 603 } 604 605 /// Parse an 'x' token in a dimension list, handling the case where the x is 606 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next 607 /// token. 608 ParseResult Parser::parseXInDimensionList() { 609 if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x') 610 return emitWrongTokenError("expected 'x' in dimension list"); 611 612 // If we had a prefix of 'x', lex the next token immediately after the 'x'. 613 if (getTokenSpelling().size() != 1) 614 state.lex.resetPointer(getTokenSpelling().data() + 1); 615 616 // Consume the 'x'. 617 consumeToken(Token::bare_identifier); 618 619 return success(); 620 } 621