//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

// Pull in the generated type class definitions.
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

// Generated type constraints are emitted into the `mlir` namespace.
namespace mlir {
#include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
} // namespace mlir

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

/// Register the builtin types with the builtin dialect. The comma-separated
/// list of type classes is expanded from the generated BuiltinTypes.cpp.inc
/// when GET_TYPEDEF_LIST is defined.
void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

/// Verify the construction of a complex type.
55 LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError, 56 Type elementType) { 57 if (!elementType.isIntOrFloat()) 58 return emitError() << "invalid element type for complex"; 59 return success(); 60 } 61 62 //===----------------------------------------------------------------------===// 63 // Integer Type 64 //===----------------------------------------------------------------------===// 65 66 /// Verify the construction of an integer type. 67 LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError, 68 unsigned width, 69 SignednessSemantics signedness) { 70 if (width > IntegerType::kMaxWidth) { 71 return emitError() << "integer bitwidth is limited to " 72 << IntegerType::kMaxWidth << " bits"; 73 } 74 return success(); 75 } 76 77 unsigned IntegerType::getWidth() const { return getImpl()->width; } 78 79 IntegerType::SignednessSemantics IntegerType::getSignedness() const { 80 return getImpl()->signedness; 81 } 82 83 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { 84 if (!scale) 85 return IntegerType(); 86 return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); 87 } 88 89 //===----------------------------------------------------------------------===// 90 // Float Types 91 //===----------------------------------------------------------------------===// 92 93 // Mapping from MLIR FloatType to APFloat semantics. 
94 #define FLOAT_TYPE_SEMANTICS(TYPE, SEM) \ 95 const llvm::fltSemantics &TYPE::getFloatSemantics() const { \ 96 return APFloat::SEM(); \ 97 } 98 FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN) 99 FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN) 100 FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN) 101 FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2) 102 FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3) 103 FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN) 104 FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ) 105 FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ) 106 FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ) 107 FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4) 108 FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU) 109 FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat) 110 FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf) 111 FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32) 112 FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle) 113 FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble) 114 FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended) 115 FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad) 116 #undef FLOAT_TYPE_SEMANTICS 117 118 FloatType Float16Type::scaleElementBitwidth(unsigned scale) const { 119 if (scale == 2) 120 return Float32Type::get(getContext()); 121 if (scale == 4) 122 return Float64Type::get(getContext()); 123 return FloatType(); 124 } 125 126 FloatType BFloat16Type::scaleElementBitwidth(unsigned scale) const { 127 if (scale == 2) 128 return Float32Type::get(getContext()); 129 if (scale == 4) 130 return Float64Type::get(getContext()); 131 return FloatType(); 132 } 133 134 FloatType Float32Type::scaleElementBitwidth(unsigned scale) const { 135 if (scale == 2) 136 return Float64Type::get(getContext()); 137 return FloatType(); 138 } 139 140 //===----------------------------------------------------------------------===// 141 // FunctionType 142 //===----------------------------------------------------------------------===// 143 144 
unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } 145 146 ArrayRef<Type> FunctionType::getInputs() const { 147 return getImpl()->getInputs(); 148 } 149 150 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } 151 152 ArrayRef<Type> FunctionType::getResults() const { 153 return getImpl()->getResults(); 154 } 155 156 FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const { 157 return get(getContext(), inputs, results); 158 } 159 160 /// Returns a new function type with the specified arguments and results 161 /// inserted. 162 FunctionType FunctionType::getWithArgsAndResults( 163 ArrayRef<unsigned> argIndices, TypeRange argTypes, 164 ArrayRef<unsigned> resultIndices, TypeRange resultTypes) { 165 SmallVector<Type> argStorage, resultStorage; 166 TypeRange newArgTypes = 167 insertTypesInto(getInputs(), argIndices, argTypes, argStorage); 168 TypeRange newResultTypes = 169 insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage); 170 return clone(newArgTypes, newResultTypes); 171 } 172 173 /// Returns a new function type without the specified arguments and results. 174 FunctionType 175 FunctionType::getWithoutArgsAndResults(const BitVector &argIndices, 176 const BitVector &resultIndices) { 177 SmallVector<Type> argStorage, resultStorage; 178 TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage); 179 TypeRange newResultTypes = 180 filterTypesOut(getResults(), resultIndices, resultStorage); 181 return clone(newArgTypes, newResultTypes); 182 } 183 184 //===----------------------------------------------------------------------===// 185 // OpaqueType 186 //===----------------------------------------------------------------------===// 187 188 /// Verify the construction of an opaque type. 
189 LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError, 190 StringAttr dialect, StringRef typeData) { 191 if (!Dialect::isValidNamespace(dialect.strref())) 192 return emitError() << "invalid dialect namespace '" << dialect << "'"; 193 194 // Check that the dialect is actually registered. 195 MLIRContext *context = dialect.getContext(); 196 if (!context->allowsUnregisteredDialects() && 197 !context->getLoadedDialect(dialect.strref())) { 198 return emitError() 199 << "`!" << dialect << "<\"" << typeData << "\">" 200 << "` type created with unregistered dialect. If this is " 201 "intended, please call allowUnregisteredDialects() on the " 202 "MLIRContext, or use -allow-unregistered-dialect with " 203 "the MLIR opt tool used"; 204 } 205 206 return success(); 207 } 208 209 //===----------------------------------------------------------------------===// 210 // VectorType 211 //===----------------------------------------------------------------------===// 212 213 bool VectorType::isValidElementType(Type t) { 214 return isValidVectorTypeElementType(t); 215 } 216 217 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError, 218 ArrayRef<int64_t> shape, Type elementType, 219 ArrayRef<bool> scalableDims) { 220 if (!isValidElementType(elementType)) 221 return emitError() 222 << "vector elements must be int/index/float type but got " 223 << elementType; 224 225 if (any_of(shape, [](int64_t i) { return i <= 0; })) 226 return emitError() 227 << "vector types must have positive constant sizes but got " 228 << shape; 229 230 if (scalableDims.size() != shape.size()) 231 return emitError() << "number of dims must match, got " 232 << scalableDims.size() << " and " << shape.size(); 233 234 return success(); 235 } 236 237 VectorType VectorType::scaleElementBitwidth(unsigned scale) { 238 if (!scale) 239 return VectorType(); 240 if (auto et = llvm::dyn_cast<IntegerType>(getElementType())) 241 if (auto scaledEt = 
et.scaleElementBitwidth(scale)) 242 return VectorType::get(getShape(), scaledEt, getScalableDims()); 243 if (auto et = llvm::dyn_cast<FloatType>(getElementType())) 244 if (auto scaledEt = et.scaleElementBitwidth(scale)) 245 return VectorType::get(getShape(), scaledEt, getScalableDims()); 246 return VectorType(); 247 } 248 249 VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape, 250 Type elementType) const { 251 return VectorType::get(shape.value_or(getShape()), elementType, 252 getScalableDims()); 253 } 254 255 //===----------------------------------------------------------------------===// 256 // TensorType 257 //===----------------------------------------------------------------------===// 258 259 Type TensorType::getElementType() const { 260 return llvm::TypeSwitch<TensorType, Type>(*this) 261 .Case<RankedTensorType, UnrankedTensorType>( 262 [](auto type) { return type.getElementType(); }); 263 } 264 265 bool TensorType::hasRank() const { 266 return !llvm::isa<UnrankedTensorType>(*this); 267 } 268 269 ArrayRef<int64_t> TensorType::getShape() const { 270 return llvm::cast<RankedTensorType>(*this).getShape(); 271 } 272 273 TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape, 274 Type elementType) const { 275 if (llvm::dyn_cast<UnrankedTensorType>(*this)) { 276 if (shape) 277 return RankedTensorType::get(*shape, elementType); 278 return UnrankedTensorType::get(elementType); 279 } 280 281 auto rankedTy = llvm::cast<RankedTensorType>(*this); 282 if (!shape) 283 return RankedTensorType::get(rankedTy.getShape(), elementType, 284 rankedTy.getEncoding()); 285 return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType, 286 rankedTy.getEncoding()); 287 } 288 289 RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape, 290 Type elementType) const { 291 return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType)); 292 } 293 294 RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) 
const { 295 return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType())); 296 } 297 298 // Check if "elementType" can be an element type of a tensor. 299 static LogicalResult 300 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError, 301 Type elementType) { 302 if (!TensorType::isValidElementType(elementType)) 303 return emitError() << "invalid tensor element type: " << elementType; 304 return success(); 305 } 306 307 /// Return true if the specified element type is ok in a tensor. 308 bool TensorType::isValidElementType(Type type) { 309 // Note: Non standard/builtin types are allowed to exist within tensor 310 // types. Dialects are expected to verify that tensor types have a valid 311 // element type within that dialect. 312 return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, 313 IndexType>(type) || 314 !llvm::isa<BuiltinDialect>(type.getDialect()); 315 } 316 317 //===----------------------------------------------------------------------===// 318 // RankedTensorType 319 //===----------------------------------------------------------------------===// 320 321 LogicalResult 322 RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 323 ArrayRef<int64_t> shape, Type elementType, 324 Attribute encoding) { 325 for (int64_t s : shape) 326 if (s < 0 && !ShapedType::isDynamic(s)) 327 return emitError() << "invalid tensor dimension size"; 328 if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) 329 if (failed(v.verifyEncoding(shape, elementType, emitError))) 330 return failure(); 331 return checkTensorElementType(emitError, elementType); 332 } 333 334 //===----------------------------------------------------------------------===// 335 // UnrankedTensorType 336 //===----------------------------------------------------------------------===// 337 338 LogicalResult 339 UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 340 Type elementType) { 341 return 
checkTensorElementType(emitError, elementType); 342 } 343 344 //===----------------------------------------------------------------------===// 345 // BaseMemRefType 346 //===----------------------------------------------------------------------===// 347 348 Type BaseMemRefType::getElementType() const { 349 return llvm::TypeSwitch<BaseMemRefType, Type>(*this) 350 .Case<MemRefType, UnrankedMemRefType>( 351 [](auto type) { return type.getElementType(); }); 352 } 353 354 bool BaseMemRefType::hasRank() const { 355 return !llvm::isa<UnrankedMemRefType>(*this); 356 } 357 358 ArrayRef<int64_t> BaseMemRefType::getShape() const { 359 return llvm::cast<MemRefType>(*this).getShape(); 360 } 361 362 BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape, 363 Type elementType) const { 364 if (llvm::dyn_cast<UnrankedMemRefType>(*this)) { 365 if (!shape) 366 return UnrankedMemRefType::get(elementType, getMemorySpace()); 367 MemRefType::Builder builder(*shape, elementType); 368 builder.setMemorySpace(getMemorySpace()); 369 return builder; 370 } 371 372 MemRefType::Builder builder(llvm::cast<MemRefType>(*this)); 373 if (shape) 374 builder.setShape(*shape); 375 builder.setElementType(elementType); 376 return builder; 377 } 378 379 MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape, 380 Type elementType) const { 381 return ::llvm::cast<MemRefType>(cloneWith(shape, elementType)); 382 } 383 384 MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const { 385 return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType())); 386 } 387 388 Attribute BaseMemRefType::getMemorySpace() const { 389 if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this)) 390 return rankedMemRefTy.getMemorySpace(); 391 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace(); 392 } 393 394 unsigned BaseMemRefType::getMemorySpaceAsInt() const { 395 if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this)) 396 return 
rankedMemRefTy.getMemorySpaceAsInt(); 397 return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt(); 398 } 399 400 //===----------------------------------------------------------------------===// 401 // MemRefType 402 //===----------------------------------------------------------------------===// 403 404 std::optional<llvm::SmallDenseSet<unsigned>> 405 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape, 406 ArrayRef<int64_t> reducedShape, 407 bool matchDynamic) { 408 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size(); 409 llvm::SmallDenseSet<unsigned> unusedDims; 410 unsigned reducedIdx = 0; 411 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) { 412 // Greedily insert `originalIdx` if match. 413 int64_t origSize = originalShape[originalIdx]; 414 // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1. 415 if (matchDynamic && reducedIdx < reducedRank && origSize != 1 && 416 (ShapedType::isDynamic(reducedShape[reducedIdx]) || 417 ShapedType::isDynamic(origSize))) { 418 reducedIdx++; 419 continue; 420 } 421 if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) { 422 reducedIdx++; 423 continue; 424 } 425 426 unusedDims.insert(originalIdx); 427 // If no match on `originalIdx`, the `originalShape` at this dimension 428 // must be 1, otherwise we bail. 429 if (origSize != 1) 430 return std::nullopt; 431 } 432 // The whole reducedShape must be scanned, otherwise we bail. 
433 if (reducedIdx != reducedRank) 434 return std::nullopt; 435 return unusedDims; 436 } 437 438 SliceVerificationResult 439 mlir::isRankReducedType(ShapedType originalType, 440 ShapedType candidateReducedType) { 441 if (originalType == candidateReducedType) 442 return SliceVerificationResult::Success; 443 444 ShapedType originalShapedType = llvm::cast<ShapedType>(originalType); 445 ShapedType candidateReducedShapedType = 446 llvm::cast<ShapedType>(candidateReducedType); 447 448 // Rank and size logic is valid for all ShapedTypes. 449 ArrayRef<int64_t> originalShape = originalShapedType.getShape(); 450 ArrayRef<int64_t> candidateReducedShape = 451 candidateReducedShapedType.getShape(); 452 unsigned originalRank = originalShape.size(), 453 candidateReducedRank = candidateReducedShape.size(); 454 if (candidateReducedRank > originalRank) 455 return SliceVerificationResult::RankTooLarge; 456 457 auto optionalUnusedDimsMask = 458 computeRankReductionMask(originalShape, candidateReducedShape); 459 460 // Sizes cannot be matched in case empty vector is returned. 461 if (!optionalUnusedDimsMask) 462 return SliceVerificationResult::SizeMismatch; 463 464 if (originalShapedType.getElementType() != 465 candidateReducedShapedType.getElementType()) 466 return SliceVerificationResult::ElemTypeMismatch; 467 468 return SliceVerificationResult::Success; 469 } 470 471 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) { 472 // Empty attribute is allowed as default memory space. 473 if (!memorySpace) 474 return true; 475 476 // Supported built-in attributes. 477 if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace)) 478 return true; 479 480 // Allow custom dialect attributes. 
481 if (!isa<BuiltinDialect>(memorySpace.getDialect())) 482 return true; 483 484 return false; 485 } 486 487 Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace, 488 MLIRContext *ctx) { 489 if (memorySpace == 0) 490 return nullptr; 491 492 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace); 493 } 494 495 Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) { 496 IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace); 497 if (intMemorySpace && intMemorySpace.getValue() == 0) 498 return nullptr; 499 500 return memorySpace; 501 } 502 503 unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) { 504 if (!memorySpace) 505 return 0; 506 507 assert(llvm::isa<IntegerAttr>(memorySpace) && 508 "Using `getMemorySpaceInteger` with non-Integer attribute"); 509 510 return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt()); 511 } 512 513 unsigned MemRefType::getMemorySpaceAsInt() const { 514 return detail::getMemorySpaceAsInt(getMemorySpace()); 515 } 516 517 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 518 MemRefLayoutAttrInterface layout, 519 Attribute memorySpace) { 520 // Use default layout for empty attribute. 521 if (!layout) 522 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( 523 shape.size(), elementType.getContext())); 524 525 // Drop default memory space value and replace it with empty attribute. 526 memorySpace = skipDefaultMemorySpace(memorySpace); 527 528 return Base::get(elementType.getContext(), shape, elementType, layout, 529 memorySpace); 530 } 531 532 MemRefType MemRefType::getChecked( 533 function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape, 534 Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) { 535 536 // Use default layout for empty attribute. 
537 if (!layout) 538 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( 539 shape.size(), elementType.getContext())); 540 541 // Drop default memory space value and replace it with empty attribute. 542 memorySpace = skipDefaultMemorySpace(memorySpace); 543 544 return Base::getChecked(emitErrorFn, elementType.getContext(), shape, 545 elementType, layout, memorySpace); 546 } 547 548 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 549 AffineMap map, Attribute memorySpace) { 550 551 // Use default layout for empty map. 552 if (!map) 553 map = AffineMap::getMultiDimIdentityMap(shape.size(), 554 elementType.getContext()); 555 556 // Wrap AffineMap into Attribute. 557 auto layout = AffineMapAttr::get(map); 558 559 // Drop default memory space value and replace it with empty attribute. 560 memorySpace = skipDefaultMemorySpace(memorySpace); 561 562 return Base::get(elementType.getContext(), shape, elementType, layout, 563 memorySpace); 564 } 565 566 MemRefType 567 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, 568 ArrayRef<int64_t> shape, Type elementType, AffineMap map, 569 Attribute memorySpace) { 570 571 // Use default layout for empty map. 572 if (!map) 573 map = AffineMap::getMultiDimIdentityMap(shape.size(), 574 elementType.getContext()); 575 576 // Wrap AffineMap into Attribute. 577 auto layout = AffineMapAttr::get(map); 578 579 // Drop default memory space value and replace it with empty attribute. 580 memorySpace = skipDefaultMemorySpace(memorySpace); 581 582 return Base::getChecked(emitErrorFn, elementType.getContext(), shape, 583 elementType, layout, memorySpace); 584 } 585 586 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 587 AffineMap map, unsigned memorySpaceInd) { 588 589 // Use default layout for empty map. 590 if (!map) 591 map = AffineMap::getMultiDimIdentityMap(shape.size(), 592 elementType.getContext()); 593 594 // Wrap AffineMap into Attribute. 
595 auto layout = AffineMapAttr::get(map); 596 597 // Convert deprecated integer-like memory space to Attribute. 598 Attribute memorySpace = 599 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); 600 601 return Base::get(elementType.getContext(), shape, elementType, layout, 602 memorySpace); 603 } 604 605 MemRefType 606 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, 607 ArrayRef<int64_t> shape, Type elementType, AffineMap map, 608 unsigned memorySpaceInd) { 609 610 // Use default layout for empty map. 611 if (!map) 612 map = AffineMap::getMultiDimIdentityMap(shape.size(), 613 elementType.getContext()); 614 615 // Wrap AffineMap into Attribute. 616 auto layout = AffineMapAttr::get(map); 617 618 // Convert deprecated integer-like memory space to Attribute. 619 Attribute memorySpace = 620 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); 621 622 return Base::getChecked(emitErrorFn, elementType.getContext(), shape, 623 elementType, layout, memorySpace); 624 } 625 626 LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 627 ArrayRef<int64_t> shape, Type elementType, 628 MemRefLayoutAttrInterface layout, 629 Attribute memorySpace) { 630 if (!BaseMemRefType::isValidElementType(elementType)) 631 return emitError() << "invalid memref element type"; 632 633 // Negative sizes are not allowed except for `kDynamic`. 
634 for (int64_t s : shape) 635 if (s < 0 && !ShapedType::isDynamic(s)) 636 return emitError() << "invalid memref size"; 637 638 assert(layout && "missing layout specification"); 639 if (failed(layout.verifyLayout(shape, emitError))) 640 return failure(); 641 642 if (!isSupportedMemorySpace(memorySpace)) 643 return emitError() << "unsupported memory space Attribute"; 644 645 return success(); 646 } 647 648 bool MemRefType::areTrailingDimsContiguous(int64_t n) { 649 if (!isLastDimUnitStride()) 650 return false; 651 652 auto memrefShape = getShape().take_back(n); 653 if (ShapedType::isDynamicShape(memrefShape)) 654 return false; 655 656 if (getLayout().isIdentity()) 657 return true; 658 659 int64_t offset; 660 SmallVector<int64_t> stridesFull; 661 if (!succeeded(getStridesAndOffset(stridesFull, offset))) 662 return false; 663 auto strides = ArrayRef<int64_t>(stridesFull).take_back(n); 664 665 if (strides.empty()) 666 return true; 667 668 // Check whether strides match "flattened" dims. 669 SmallVector<int64_t> flattenedDims; 670 auto dimProduct = 1; 671 for (auto dim : llvm::reverse(memrefShape.drop_front(1))) { 672 dimProduct *= dim; 673 flattenedDims.push_back(dimProduct); 674 } 675 676 strides = strides.drop_back(1); 677 return llvm::equal(strides, llvm::reverse(flattenedDims)); 678 } 679 680 MemRefType MemRefType::canonicalizeStridedLayout() { 681 AffineMap m = getLayout().getAffineMap(); 682 683 // Already in canonical form. 684 if (m.isIdentity()) 685 return *this; 686 687 // Can't reduce to canonical identity form, return in canonical form. 688 if (m.getNumResults() > 1) 689 return *this; 690 691 // Corner-case for 0-D affine maps. 692 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) { 693 if (auto cst = llvm::dyn_cast<AffineConstantExpr>(m.getResult(0))) 694 if (cst.getValue() == 0) 695 return MemRefType::Builder(*this).setLayout({}); 696 return *this; 697 } 698 699 // 0-D corner case for empty shape that still have an affine map. 
Example: 700 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose 701 // offset needs to remain, just return t. 702 if (getShape().empty()) 703 return *this; 704 705 // If the canonical strided layout for the sizes of `t` is equal to the 706 // simplified layout of `t` we can just return an empty layout. Otherwise, 707 // just simplify the existing layout. 708 AffineExpr expr = makeCanonicalStridedLayoutExpr(getShape(), getContext()); 709 auto simplifiedLayoutExpr = 710 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 711 if (expr != simplifiedLayoutExpr) 712 return MemRefType::Builder(*this).setLayout( 713 AffineMapAttr::get(AffineMap::get(m.getNumDims(), m.getNumSymbols(), 714 simplifiedLayoutExpr))); 715 return MemRefType::Builder(*this).setLayout({}); 716 } 717 718 // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( 719 // i.e. single term). Accumulate the AffineExpr into the existing one. 720 static void extractStridesFromTerm(AffineExpr e, 721 AffineExpr multiplicativeFactor, 722 MutableArrayRef<AffineExpr> strides, 723 AffineExpr &offset) { 724 if (auto dim = dyn_cast<AffineDimExpr>(e)) 725 strides[dim.getPosition()] = 726 strides[dim.getPosition()] + multiplicativeFactor; 727 else 728 offset = offset + e * multiplicativeFactor; 729 } 730 731 /// Takes a single AffineExpr `e` and populates the `strides` array with the 732 /// strides expressions for each dim position. 733 /// The convention is that the strides for dimensions d0, .. dn appear in 734 /// order to make indexing intuitive into the result. 
735 static LogicalResult extractStrides(AffineExpr e, 736 AffineExpr multiplicativeFactor, 737 MutableArrayRef<AffineExpr> strides, 738 AffineExpr &offset) { 739 auto bin = dyn_cast<AffineBinaryOpExpr>(e); 740 if (!bin) { 741 extractStridesFromTerm(e, multiplicativeFactor, strides, offset); 742 return success(); 743 } 744 745 if (bin.getKind() == AffineExprKind::CeilDiv || 746 bin.getKind() == AffineExprKind::FloorDiv || 747 bin.getKind() == AffineExprKind::Mod) 748 return failure(); 749 750 if (bin.getKind() == AffineExprKind::Mul) { 751 auto dim = dyn_cast<AffineDimExpr>(bin.getLHS()); 752 if (dim) { 753 strides[dim.getPosition()] = 754 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; 755 return success(); 756 } 757 // LHS and RHS may both contain complex expressions of dims. Try one path 758 // and if it fails try the other. This is guaranteed to succeed because 759 // only one path may have a `dim`, otherwise this is not an AffineExpr in 760 // the first place. 761 if (bin.getLHS().isSymbolicOrConstant()) 762 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), 763 strides, offset); 764 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), 765 strides, offset); 766 } 767 768 if (bin.getKind() == AffineExprKind::Add) { 769 auto res1 = 770 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); 771 auto res2 = 772 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); 773 return success(succeeded(res1) && succeeded(res2)); 774 } 775 776 llvm_unreachable("unexpected binary operation"); 777 } 778 779 /// A stride specification is a list of integer values that are either static 780 /// or dynamic (encoded with ShapedType::kDynamic). Strides encode 781 /// the distance in the number of elements between successive entries along a 782 /// particular dimension. 
783 /// 784 /// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a 785 /// non-contiguous memory region of `42` by `16` `f32` elements in which the 786 /// distance between two consecutive elements along the outer dimension is `1` 787 /// and the distance between two consecutive elements along the inner dimension 788 /// is `64`. 789 /// 790 /// The convention is that the strides for dimensions d0, .. dn appear in 791 /// order to make indexing intuitive into the result. 792 static LogicalResult getStridesAndOffset(MemRefType t, 793 SmallVectorImpl<AffineExpr> &strides, 794 AffineExpr &offset) { 795 AffineMap m = t.getLayout().getAffineMap(); 796 797 if (m.getNumResults() != 1 && !m.isIdentity()) 798 return failure(); 799 800 auto zero = getAffineConstantExpr(0, t.getContext()); 801 auto one = getAffineConstantExpr(1, t.getContext()); 802 offset = zero; 803 strides.assign(t.getRank(), zero); 804 805 // Canonical case for empty map. 806 if (m.isIdentity()) { 807 // 0-D corner case, offset is already 0. 808 if (t.getRank() == 0) 809 return success(); 810 auto stridedExpr = 811 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 812 if (succeeded(extractStrides(stridedExpr, one, strides, offset))) 813 return success(); 814 assert(false && "unexpected failure: extract strides in canonical layout"); 815 } 816 817 // Non-canonical case requires more work. 818 auto stridedExpr = 819 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 820 if (failed(extractStrides(stridedExpr, one, strides, offset))) { 821 offset = AffineExpr(); 822 strides.clear(); 823 return failure(); 824 } 825 826 // Simplify results to allow folding to constants and simple checks. 
827 unsigned numDims = m.getNumDims(); 828 unsigned numSymbols = m.getNumSymbols(); 829 offset = simplifyAffineExpr(offset, numDims, numSymbols); 830 for (auto &stride : strides) 831 stride = simplifyAffineExpr(stride, numDims, numSymbols); 832 833 return success(); 834 } 835 836 LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides, 837 int64_t &offset) { 838 // Happy path: the type uses the strided layout directly. 839 if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(getLayout())) { 840 llvm::append_range(strides, strided.getStrides()); 841 offset = strided.getOffset(); 842 return success(); 843 } 844 845 // Otherwise, defer to the affine fallback as layouts are supposed to be 846 // convertible to affine maps. 847 AffineExpr offsetExpr; 848 SmallVector<AffineExpr, 4> strideExprs; 849 if (failed(::getStridesAndOffset(*this, strideExprs, offsetExpr))) 850 return failure(); 851 if (auto cst = llvm::dyn_cast<AffineConstantExpr>(offsetExpr)) 852 offset = cst.getValue(); 853 else 854 offset = ShapedType::kDynamic; 855 for (auto e : strideExprs) { 856 if (auto c = llvm::dyn_cast<AffineConstantExpr>(e)) 857 strides.push_back(c.getValue()); 858 else 859 strides.push_back(ShapedType::kDynamic); 860 } 861 return success(); 862 } 863 864 std::pair<SmallVector<int64_t>, int64_t> MemRefType::getStridesAndOffset() { 865 SmallVector<int64_t> strides; 866 int64_t offset; 867 LogicalResult status = getStridesAndOffset(strides, offset); 868 (void)status; 869 assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset"); 870 return {strides, offset}; 871 } 872 873 bool MemRefType::isStrided() { 874 int64_t offset; 875 SmallVector<int64_t, 4> strides; 876 auto res = getStridesAndOffset(strides, offset); 877 return succeeded(res); 878 } 879 880 bool MemRefType::isLastDimUnitStride() { 881 int64_t offset; 882 SmallVector<int64_t> strides; 883 auto successStrides = getStridesAndOffset(strides, offset); 884 return succeeded(successStrides) && 
(strides.empty() || strides.back() == 1); 885 } 886 887 //===----------------------------------------------------------------------===// 888 // UnrankedMemRefType 889 //===----------------------------------------------------------------------===// 890 891 unsigned UnrankedMemRefType::getMemorySpaceAsInt() const { 892 return detail::getMemorySpaceAsInt(getMemorySpace()); 893 } 894 895 LogicalResult 896 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 897 Type elementType, Attribute memorySpace) { 898 if (!BaseMemRefType::isValidElementType(elementType)) 899 return emitError() << "invalid memref element type"; 900 901 if (!isSupportedMemorySpace(memorySpace)) 902 return emitError() << "unsupported memory space Attribute"; 903 904 return success(); 905 } 906 907 //===----------------------------------------------------------------------===// 908 /// TupleType 909 //===----------------------------------------------------------------------===// 910 911 /// Return the elements types for this tuple. 912 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } 913 914 /// Accumulate the types contained in this tuple and tuples nested within it. 915 /// Note that this only flattens nested tuples, not any other container type, 916 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to 917 /// (i32, tensor<i32>, f32, i64) 918 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { 919 for (Type type : getTypes()) { 920 if (auto nestedTuple = llvm::dyn_cast<TupleType>(type)) 921 nestedTuple.getFlattenedTypes(types); 922 else 923 types.push_back(type); 924 } 925 } 926 927 /// Return the number of element types. 
928 size_t TupleType::size() const { return getImpl()->size(); } 929 930 //===----------------------------------------------------------------------===// 931 // Type Utilities 932 //===----------------------------------------------------------------------===// 933 934 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 935 ArrayRef<AffineExpr> exprs, 936 MLIRContext *context) { 937 // Size 0 corner case is useful for canonicalizations. 938 if (sizes.empty()) 939 return getAffineConstantExpr(0, context); 940 941 assert(!exprs.empty() && "expected exprs"); 942 auto maps = AffineMap::inferFromExprList(exprs, context); 943 assert(!maps.empty() && "Expected one non-empty map"); 944 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); 945 946 AffineExpr expr; 947 bool dynamicPoisonBit = false; 948 int64_t runningSize = 1; 949 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { 950 int64_t size = std::get<1>(en); 951 AffineExpr dimExpr = std::get<0>(en); 952 AffineExpr stride = dynamicPoisonBit 953 ? getAffineSymbolExpr(nSymbols++, context) 954 : getAffineConstantExpr(runningSize, context); 955 expr = expr ? expr + dimExpr * stride : dimExpr * stride; 956 if (size > 0) { 957 runningSize *= size; 958 assert(runningSize > 0 && "integer overflow in size computation"); 959 } else { 960 dynamicPoisonBit = true; 961 } 962 } 963 return simplifyAffineExpr(expr, numDims, nSymbols); 964 } 965 966 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 967 MLIRContext *context) { 968 SmallVector<AffineExpr, 4> exprs; 969 exprs.reserve(sizes.size()); 970 for (auto dim : llvm::seq<unsigned>(0, sizes.size())) 971 exprs.push_back(getAffineDimExpr(dim, context)); 972 return makeCanonicalStridedLayoutExpr(sizes, exprs, context); 973 } 974