//===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinAttributes.h"
#include "AttributeDetail.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Endian.h"
#include <optional>

#define DEBUG_TYPE "builtinattributes"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
/// Tablegen Attribute Definitions
//===----------------------------------------------------------------------===//

// Expands the tablegen-generated attribute class definitions.
#define GET_ATTRDEF_CLASSES
#include "mlir/IR/BuiltinAttributes.cpp.inc"

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

/// Registers the builtin attributes with the builtin dialect.
///
/// The template argument list of the first addAttributes call is produced by
/// re-including the generated .inc file with GET_ATTRDEF_LIST defined, so the
/// preprocessor directives must stay on their own lines inside the `< >`.
/// DistinctAttr is not part of that generated list and is added separately.
void BuiltinDialect::registerAttributes() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/IR/BuiltinAttributes.cpp.inc"
      >();
  addAttributes<DistinctAttr>();
}

//===----------------------------------------------------------------------===//
// DictionaryAttr
//===----------------------------------------------------------------------===//
either an in place sort or sorts from source array 56 /// into destination. If inPlace then storage is both the source and the 57 /// destination, else value is the source and storage destination. Returns 58 /// whether source was sorted. 59 template <bool inPlace> 60 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, 61 SmallVectorImpl<NamedAttribute> &storage) { 62 // Specialize for the common case. 63 switch (value.size()) { 64 case 0: 65 // Zero already sorted. 66 if (!inPlace) 67 storage.clear(); 68 break; 69 case 1: 70 // One already sorted but may need to be copied. 71 if (!inPlace) 72 storage.assign({value[0]}); 73 break; 74 case 2: { 75 bool isSorted = value[0] < value[1]; 76 if (inPlace) { 77 if (!isSorted) 78 std::swap(storage[0], storage[1]); 79 } else if (isSorted) { 80 storage.assign({value[0], value[1]}); 81 } else { 82 storage.assign({value[1], value[0]}); 83 } 84 return !isSorted; 85 } 86 default: 87 if (!inPlace) 88 storage.assign(value.begin(), value.end()); 89 // Check to see they are sorted already. 90 bool isSorted = llvm::is_sorted(value); 91 // If not, do a general sort. 92 if (!isSorted) 93 llvm::array_pod_sort(storage.begin(), storage.end()); 94 return !isSorted; 95 } 96 return false; 97 } 98 99 /// Returns an entry with a duplicate name from the given sorted array of named 100 /// attributes. Returns std::nullopt if all elements have unique names. 101 static std::optional<NamedAttribute> 102 findDuplicateElement(ArrayRef<NamedAttribute> value) { 103 const std::optional<NamedAttribute> none{std::nullopt}; 104 if (value.size() < 2) 105 return none; 106 107 if (value.size() == 2) 108 return value[0].getName() == value[1].getName() ? value[0] : none; 109 110 const auto *it = std::adjacent_find(value.begin(), value.end(), 111 [](NamedAttribute l, NamedAttribute r) { 112 return l.getName() == r.getName(); 113 }); 114 return it != value.end() ? 
*it : none; 115 } 116 117 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, 118 SmallVectorImpl<NamedAttribute> &storage) { 119 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage); 120 assert(!findDuplicateElement(storage) && 121 "DictionaryAttr element names must be unique"); 122 return isSorted; 123 } 124 125 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { 126 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array); 127 assert(!findDuplicateElement(array) && 128 "DictionaryAttr element names must be unique"); 129 return isSorted; 130 } 131 132 std::optional<NamedAttribute> 133 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array, 134 bool isSorted) { 135 if (!isSorted) 136 dictionaryAttrSort</*inPlace=*/true>(array, array); 137 return findDuplicateElement(array); 138 } 139 140 DictionaryAttr DictionaryAttr::get(MLIRContext *context, 141 ArrayRef<NamedAttribute> value) { 142 if (value.empty()) 143 return DictionaryAttr::getEmpty(context); 144 145 // We need to sort the element list to canonicalize it. 146 SmallVector<NamedAttribute, 8> storage; 147 if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) 148 value = storage; 149 assert(!findDuplicateElement(value) && 150 "DictionaryAttr element names must be unique"); 151 return Base::get(context, value); 152 } 153 /// Construct a dictionary with an array of values that is known to already be 154 /// sorted by name and uniqued. 155 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context, 156 ArrayRef<NamedAttribute> value) { 157 if (value.empty()) 158 return DictionaryAttr::getEmpty(context); 159 // Ensure that the attribute elements are unique and sorted. 
160 assert(llvm::is_sorted( 161 value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && 162 "expected attribute values to be sorted"); 163 assert(!findDuplicateElement(value) && 164 "DictionaryAttr element names must be unique"); 165 return Base::get(context, value); 166 } 167 168 /// Return the specified attribute if present, null otherwise. 169 Attribute DictionaryAttr::get(StringRef name) const { 170 auto it = impl::findAttrSorted(begin(), end(), name); 171 return it.second ? it.first->getValue() : Attribute(); 172 } 173 Attribute DictionaryAttr::get(StringAttr name) const { 174 auto it = impl::findAttrSorted(begin(), end(), name); 175 return it.second ? it.first->getValue() : Attribute(); 176 } 177 178 /// Return the specified named attribute if present, std::nullopt otherwise. 179 std::optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { 180 auto it = impl::findAttrSorted(begin(), end(), name); 181 return it.second ? *it.first : std::optional<NamedAttribute>(); 182 } 183 std::optional<NamedAttribute> DictionaryAttr::getNamed(StringAttr name) const { 184 auto it = impl::findAttrSorted(begin(), end(), name); 185 return it.second ? *it.first : std::optional<NamedAttribute>(); 186 } 187 188 /// Return whether the specified attribute is present. 
189 bool DictionaryAttr::contains(StringRef name) const { 190 return impl::findAttrSorted(begin(), end(), name).second; 191 } 192 bool DictionaryAttr::contains(StringAttr name) const { 193 return impl::findAttrSorted(begin(), end(), name).second; 194 } 195 196 DictionaryAttr::iterator DictionaryAttr::begin() const { 197 return getValue().begin(); 198 } 199 DictionaryAttr::iterator DictionaryAttr::end() const { 200 return getValue().end(); 201 } 202 size_t DictionaryAttr::size() const { return getValue().size(); } 203 204 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) { 205 return Base::get(context, ArrayRef<NamedAttribute>()); 206 } 207 208 //===----------------------------------------------------------------------===// 209 // StridedLayoutAttr 210 //===----------------------------------------------------------------------===// 211 212 /// Prints a strided layout attribute. 213 void StridedLayoutAttr::print(llvm::raw_ostream &os) const { 214 auto printIntOrQuestion = [&](int64_t value) { 215 if (ShapedType::isDynamic(value)) 216 os << "?"; 217 else 218 os << value; 219 }; 220 221 os << "strided<["; 222 llvm::interleaveComma(getStrides(), os, printIntOrQuestion); 223 os << "]"; 224 225 if (getOffset() != 0) { 226 os << ", offset: "; 227 printIntOrQuestion(getOffset()); 228 } 229 os << ">"; 230 } 231 232 /// Returns true if this layout is static, i.e. the strides and offset all have 233 /// a known value > 0. 234 bool StridedLayoutAttr::hasStaticLayout() const { 235 return !ShapedType::isDynamic(getOffset()) && 236 !ShapedType::isDynamicShape(getStrides()); 237 } 238 239 /// Returns the strided layout as an affine map. 240 AffineMap StridedLayoutAttr::getAffineMap() const { 241 return makeStridedLinearLayoutMap(getStrides(), getOffset(), getContext()); 242 } 243 244 /// Checks that the type-agnostic strided layout invariants are satisfied. 
245 LogicalResult 246 StridedLayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError, 247 int64_t offset, ArrayRef<int64_t> strides) { 248 return success(); 249 } 250 251 /// Checks that the type-specific strided layout invariants are satisfied. 252 LogicalResult StridedLayoutAttr::verifyLayout( 253 ArrayRef<int64_t> shape, 254 function_ref<InFlightDiagnostic()> emitError) const { 255 if (shape.size() != getStrides().size()) 256 return emitError() << "expected the number of strides to match the rank"; 257 258 return success(); 259 } 260 261 //===----------------------------------------------------------------------===// 262 // StringAttr 263 //===----------------------------------------------------------------------===// 264 265 StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) { 266 return Base::get(context, "", NoneType::get(context)); 267 } 268 269 /// Twine support for StringAttr. 270 StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) { 271 // Fast-path empty twine. 272 if (twine.isTriviallyEmpty()) 273 return get(context); 274 SmallVector<char, 32> tempStr; 275 return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context)); 276 } 277 278 /// Twine support for StringAttr. 
/// Builds a StringAttr of the given `type` from a Twine. The Twine is first
/// materialized into a stack buffer; the context is taken from `type`.
StringAttr StringAttr::get(const Twine &twine, Type type) {
  SmallVector<char, 32> tempStr;
  return Base::get(type.getContext(), twine.toStringRef(tempStr), type);
}

/// Returns the string contents held in the attribute storage.
StringRef StringAttr::getValue() const { return getImpl()->value; }

/// Returns the type stored alongside the string.
Type StringAttr::getType() const { return getImpl()->type; }

/// Returns the dialect recorded in the storage's `referencedDialect` field.
Dialect *StringAttr::getReferencedDialect() const {
  return getImpl()->referencedDialect;
}

//===----------------------------------------------------------------------===//
// FloatAttr
//===----------------------------------------------------------------------===//

/// Returns the stored value converted to a host double.
double FloatAttr::getValueAsDouble() const {
  return getValueAsDouble(getValue());
}
/// Converts `value` to a host double. Non-IEEE-double semantics are converted
/// with round-to-nearest-ties-to-even; any precision loss is ignored
/// (`losesInfo` is never inspected).
double FloatAttr::getValueAsDouble(APFloat value) {
  if (&value.getSemantics() != &APFloat::IEEEdouble()) {
    bool losesInfo = false;
    value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
                  &losesInfo);
  }
  return value.convertToDouble();
}

/// Verifies that `type` is a FloatType whose semantics are exactly those of
/// `value` (semantics are compared by identity).
LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                Type type, APFloat value) {
  // Verify that the type is correct.
  if (!llvm::isa<FloatType>(type))
    return emitError() << "expected floating point type";

  // Verify that the type semantics match that of the value.
  if (&llvm::cast<FloatType>(type).getFloatSemantics() !=
      &value.getSemantics()) {
    return emitError()
           << "FloatAttr type doesn't match the type implied by its value";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// SymbolRefAttr
//===----------------------------------------------------------------------===//

/// Builds a symbol reference whose root is the string `value`.
SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
                                 ArrayRef<FlatSymbolRefAttr> nestedRefs) {
  return get(StringAttr::get(ctx, value), nestedRefs);
}

/// Builds a flat (no nested references) symbol reference from a string.
FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
  return llvm::cast<FlatSymbolRefAttr>(get(ctx, value, {}));
}

/// Builds a flat (no nested references) symbol reference from a StringAttr.
FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) {
  return llvm::cast<FlatSymbolRefAttr>(get(value, {}));
}

/// Builds a flat reference to `symbol` from its symbol-name attribute. The
/// operation must carry that attribute (checked only by assert).
FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) {
  auto symName =
      symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
  assert(symName && "value does not have a valid symbol name");
  return SymbolRefAttr::get(symName);
}

/// Returns the innermost reference: the last nested reference if any exist,
/// otherwise the root reference.
StringAttr SymbolRefAttr::getLeafReference() const {
  ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
  return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
}

//===----------------------------------------------------------------------===//
// IntegerAttr
//===----------------------------------------------------------------------===//

/// Returns the value sign-extended to int64_t. Only valid for index or
/// signless integer types (asserted).
int64_t IntegerAttr::getInt() const {
  assert((getType().isIndex() || getType().isSignlessInteger()) &&
         "must be signless integer");
  return getValue().getSExtValue();
}

/// Returns the value sign-extended to int64_t. Only valid for signed integer
/// types (asserted).
int64_t IntegerAttr::getSInt() const {
  assert(getType().isSignedInteger() && "must be signed integer");
  return getValue().getSExtValue();
}

/// Returns the value zero-extended to uint64_t. Only valid for unsigned
/// integer types (asserted).
uint64_t IntegerAttr::getUInt() const {
  assert(getType().isUnsignedInteger() && "must be unsigned integer");
  return getValue().getZExtValue();
}

/// Return the value as an APSInt which carries the sign from the type of
/// the attribute. This traps (asserts) on signless integer types!
APSInt IntegerAttr::getAPSInt() const {
  assert(!getType().isSignlessInteger() &&
         "Signless integers don't carry a sign for APSInt");
  return APSInt(getValue(), getType().isUnsignedInteger());
}

/// Verifies that `type` is an integer type whose width equals the APInt's
/// width, or the index type with an APInt of
/// IndexType::kInternalStorageBitWidth bits.
LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                  Type type, APInt value) {
  if (IntegerType integerType = llvm::dyn_cast<IntegerType>(type)) {
    if (integerType.getWidth() != value.getBitWidth())
      return emitError() << "integer type bit width (" << integerType.getWidth()
                         << ") doesn't match value bit width ("
                         << value.getBitWidth() << ")";
    return success();
  }
  if (llvm::isa<IndexType>(type)) {
    if (value.getBitWidth() != IndexType::kInternalStorageBitWidth)
      return emitError()
             << "value bit width (" << value.getBitWidth()
             << ") doesn't match index type internal storage bit width ("
             << IndexType::kInternalStorageBitWidth << ")";
    return success();
  }
  return emitError() << "expected integer or index type";
}

/// Builds a BoolAttr by uniquing an IntegerAttr of `type` holding a 1-bit
/// APInt equal to `value`, then casting the result to BoolAttr.
BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
  auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value));
  return llvm::cast<BoolAttr>(attr);
}

//===----------------------------------------------------------------------===//
// BoolAttr
//===----------------------------------------------------------------------===//

/// Reads the boolean straight out of the underlying IntegerAttrStorage. This
/// is valid because classof below only accepts IntegerAttrs.
bool BoolAttr::getValue() const {
  auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
  return storage->value.getBoolValue();
}

/// A BoolAttr is any IntegerAttr whose type is the signless 1-bit integer.
bool BoolAttr::classof(Attribute attr) {
  IntegerAttr intAttr = llvm::dyn_cast<IntegerAttr>(attr);
  return intAttr && intAttr.getType().isSignlessInteger(1);
}

//===----------------------------------------------------------------------===//
// OpaqueAttr
//===----------------------------------------------------------------------===//

/// Verifies an opaque attribute. The dialect namespace must be syntactically
/// valid, and the dialect must be loaded unless the context allows
/// unregistered dialects.
LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 StringAttr dialect, StringRef attrData,
                                 Type type) {
  if (!Dialect::isValidNamespace(dialect.strref()))
    return emitError() << "invalid dialect namespace '" << dialect << "'";

  // Check that the dialect is actually registered.
  MLIRContext *context = dialect.getContext();
  if (!context->allowsUnregisteredDialects() &&
      !context->getLoadedDialect(dialect.strref())) {
    return emitError()
           << "#" << dialect << "<\"" << attrData << "\"> : " << type
           << " attribute created with unregistered dialect. If this is "
              "intended, please call allowUnregisteredDialects() on the "
              "MLIRContext, or use -allow-unregistered-dialect with "
              "the MLIR opt tool used";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// DenseElementsAttr Utilities
//===----------------------------------------------------------------------===//

// Byte values with all bits set (~0) and no bits set (0). Presumably these are
// the canonical byte encodings of true/false boolean splats; confirm against
// DenseIntOrFPElementsAttrStorage in AttributeDetail.h.
const char DenseIntOrFPElementsAttrStorage::kSplatTrue = ~0;
const char DenseIntOrFPElementsAttrStorage::kSplatFalse = 0;

/// Get the bitwidth of a dense element type within the buffer.
/// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
static size_t getDenseElementStorageWidth(size_t origWidth) {
  return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
}
/// Storage width for an element type: its dense bitwidth rounded as above.
static size_t getDenseElementStorageWidth(Type elementType) {
  return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
}

/// Set a bit to a specific value. Bit `bitPos` lives in byte
/// `bitPos / CHAR_BIT` at position `bitPos % CHAR_BIT`, counting from the
/// least-significant bit.
static void setBit(char *rawData, size_t bitPos, bool value) {
  if (value)
    rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT));
  else
    rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT));
}

/// Return the value of the specified bit (same numbering as setBit).
static bool getBit(const char *rawData, size_t bitPos) {
  return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0;
}

/// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for
/// BE format. Only meaningful on big-endian hosts (asserted).
static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
                                         char *result) {
  assert(llvm::endianness::native == llvm::endianness::big);
  assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the words filled with data.
  // For example, when `value` has 2 words, the first word is filled with data.
  // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
  // Note: despite its name, `numFilledWords` is a byte count (all words except
  // the last, times the word size in bytes).
  size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
              numFilledWords, result);
  // Convert last word of APInt to LE format and store it in char
  // array(`valueLE`).
  // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------|
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
      valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
  // Extract actual APInt data from `valueLE`, convert endianness to BE format,
  // and store it in `result`.
  // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      valueLE.begin(), result + lastWordPos,
      (numBytes - lastWordPos) * CHAR_BIT, 1);
}

/// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
/// format. Inverse of copyAPIntToArrayForBEmachine; only meaningful on
/// big-endian hosts (asserted).
static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
                                         APInt &result) {
  assert(llvm::endianness::native == llvm::endianness::big);
  assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the data that fills the word of `result` from `inArray`.
  // For example, when `result` has 2 words, the first word will be filled with
  // data. So, the first 8 bytes are copied from `inArray` here.
  // `inArray` (10 bytes, BE): |abcdefgh|ij|
  //                     ==> `result` (2 words, BE): |abcdefgh|--------|
  // Note: `numFilledWords` is a byte count, as in the function above.
  size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(
      inArray, numFilledWords,
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));

  // Convert array data which will be last word of `result` to LE format, and
  // store it in char array(`inArrayLE`).
  // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------|
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArray + lastWordPos, inArrayLE.begin(),
      (numBytes - lastWordPos) * CHAR_BIT, 1);

  // Convert `inArrayLE` to BE format, and store it in last word of `result`.
  // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArrayLE.begin(),
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
          lastWordPos,
      APInt::APINT_BITS_PER_WORD, 1);
}

/// Writes value to the bit position `bitPos` in array `rawData`.
/// A 1-bit value updates a single bit. Wider values require a byte-aligned
/// `bitPos` and copy ceil(bitWidth / CHAR_BIT) bytes of the APInt's raw words,
/// using a byte-reordering path on big-endian hosts.
static void writeBits(char *rawData, size_t bitPos, APInt value) {
  size_t bitWidth = value.getBitWidth();

  // If the bitwidth is 1 we just toggle the specific bit.
  if (bitWidth == 1)
    return setBit(rawData, bitPos, value.isOne());

  // Otherwise, the bit position is guaranteed to be byte aligned.
  assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
  if (llvm::endianness::native == llvm::endianness::big) {
    // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
    // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
    // work correctly in BE format.
    // ex. `value` (2 words including 10 bytes)
    // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------|
    copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT),
                                 rawData + (bitPos / CHAR_BIT));
  } else {
    std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
                llvm::divideCeil(bitWidth, CHAR_BIT),
                rawData + (bitPos / CHAR_BIT));
  }
}

/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
/// `rawData` and returns them as a `bitWidth`-bit APInt. This is the inverse of
/// writeBits and has the same alignment requirement for widths above 1.
static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
  // Handle a boolean bit position.
  if (bitWidth == 1)
    return APInt(1, getBit(rawData, bitPos) ? 1 : 0);

  // Otherwise, the bit position must be 8-bit aligned.
  assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
  APInt result(bitWidth, 0);
  if (llvm::endianness::native == llvm::endianness::big) {
    // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
    // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
    // work correctly in BE format.
    // ex. `result` (2 words including 10 bytes)
    // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------|
    copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
                                 llvm::divideCeil(bitWidth, CHAR_BIT), result);
  } else {
    std::copy_n(rawData + (bitPos / CHAR_BIT),
                llvm::divideCeil(bitWidth, CHAR_BIT),
                const_cast<char *>(
                    reinterpret_cast<const char *>(result.getRawData())));
  }
  return result;
}
591 template <typename Values> 592 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 593 return (values.size() == 1) || 594 (type.getNumElements() == static_cast<int64_t>(values.size())); 595 } 596 597 //===----------------------------------------------------------------------===// 598 // DenseElementsAttr Iterators 599 //===----------------------------------------------------------------------===// 600 601 //===----------------------------------------------------------------------===// 602 // AttributeElementIterator 603 604 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 605 DenseElementsAttr attr, size_t index) 606 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, 607 Attribute, Attribute, Attribute>( 608 attr.getAsOpaquePointer(), index) {} 609 610 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 611 auto owner = llvm::cast<DenseElementsAttr>(getFromOpaquePointer(base)); 612 Type eltTy = owner.getElementType(); 613 if (llvm::dyn_cast<IntegerType>(eltTy)) 614 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 615 if (llvm::isa<IndexType>(eltTy)) 616 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 617 if (auto floatEltTy = llvm::dyn_cast<FloatType>(eltTy)) { 618 IntElementIterator intIt(owner, index); 619 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 620 return FloatAttr::get(eltTy, *floatIt); 621 } 622 if (auto complexTy = llvm::dyn_cast<ComplexType>(eltTy)) { 623 auto complexEltTy = complexTy.getElementType(); 624 ComplexIntElementIterator complexIntIt(owner, index); 625 if (llvm::isa<IntegerType>(complexEltTy)) { 626 auto value = *complexIntIt; 627 auto real = IntegerAttr::get(complexEltTy, value.real()); 628 auto imag = IntegerAttr::get(complexEltTy, value.imag()); 629 return ArrayAttr::get(complexTy.getContext(), 630 ArrayRef<Attribute>{real, imag}); 631 } 632 633 ComplexFloatElementIterator 
complexFloatIt( 634 llvm::cast<FloatType>(complexEltTy).getFloatSemantics(), complexIntIt); 635 auto value = *complexFloatIt; 636 auto real = FloatAttr::get(complexEltTy, value.real()); 637 auto imag = FloatAttr::get(complexEltTy, value.imag()); 638 return ArrayAttr::get(complexTy.getContext(), 639 ArrayRef<Attribute>{real, imag}); 640 } 641 if (llvm::isa<DenseStringElementsAttr>(owner)) { 642 ArrayRef<StringRef> vals = owner.getRawStringData(); 643 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); 644 } 645 llvm_unreachable("unexpected element type"); 646 } 647 648 //===----------------------------------------------------------------------===// 649 // BoolElementIterator 650 651 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 652 DenseElementsAttr attr, size_t dataIndex) 653 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 654 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 655 656 bool DenseElementsAttr::BoolElementIterator::operator*() const { 657 return getBit(getData(), getDataIndex()); 658 } 659 660 //===----------------------------------------------------------------------===// 661 // IntElementIterator 662 663 DenseElementsAttr::IntElementIterator::IntElementIterator( 664 DenseElementsAttr attr, size_t dataIndex) 665 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 666 attr.getRawData().data(), attr.isSplat(), dataIndex), 667 bitWidth(getDenseElementBitWidth(attr.getElementType())) {} 668 669 APInt DenseElementsAttr::IntElementIterator::operator*() const { 670 return readBits(getData(), 671 getDataIndex() * getDenseElementStorageWidth(bitWidth), 672 bitWidth); 673 } 674 675 //===----------------------------------------------------------------------===// 676 // ComplexIntElementIterator 677 678 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( 679 DenseElementsAttr attr, size_t dataIndex) 680 : 
DenseElementIndexedIteratorImpl<ComplexIntElementIterator, 681 std::complex<APInt>, std::complex<APInt>, 682 std::complex<APInt>>( 683 attr.getRawData().data(), attr.isSplat(), dataIndex) { 684 auto complexType = llvm::cast<ComplexType>(attr.getElementType()); 685 bitWidth = getDenseElementBitWidth(complexType.getElementType()); 686 } 687 688 std::complex<APInt> 689 DenseElementsAttr::ComplexIntElementIterator::operator*() const { 690 size_t storageWidth = getDenseElementStorageWidth(bitWidth); 691 size_t offset = getDataIndex() * storageWidth * 2; 692 return {readBits(getData(), offset, bitWidth), 693 readBits(getData(), offset + storageWidth, bitWidth)}; 694 } 695 696 //===----------------------------------------------------------------------===// 697 // DenseArrayAttr 698 //===----------------------------------------------------------------------===// 699 700 LogicalResult 701 DenseArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError, 702 Type elementType, int64_t size, ArrayRef<char> rawData) { 703 if (!elementType.isIntOrIndexOrFloat()) 704 return emitError() << "expected integer or floating point element type"; 705 int64_t dataSize = rawData.size(); 706 int64_t elementSize = 707 llvm::divideCeil(elementType.getIntOrFloatBitWidth(), CHAR_BIT); 708 if (size * elementSize != dataSize) { 709 return emitError() << "expected data size (" << size << " elements, " 710 << elementSize 711 << " bytes each) does not match: " << dataSize 712 << " bytes"; 713 } 714 return success(); 715 } 716 717 namespace { 718 /// Instantiations of this class provide utilities for interacting with native 719 /// data types in the context of DenseArrayAttr. 
720 template <size_t width, 721 IntegerType::SignednessSemantics signedness = IntegerType::Signless> 722 struct DenseArrayAttrIntUtil { 723 static bool checkElementType(Type eltType) { 724 auto type = llvm::dyn_cast<IntegerType>(eltType); 725 if (!type || type.getWidth() != width) 726 return false; 727 return type.getSignedness() == signedness; 728 } 729 730 static Type getElementType(MLIRContext *ctx) { 731 return IntegerType::get(ctx, width, signedness); 732 } 733 734 template <typename T> 735 static void printElement(raw_ostream &os, T value) { 736 os << value; 737 } 738 739 template <typename T> 740 static ParseResult parseElement(AsmParser &parser, T &value) { 741 return parser.parseInteger(value); 742 } 743 }; 744 template <typename T> 745 struct DenseArrayAttrUtil; 746 747 /// Specialization for boolean elements to print 'true' and 'false' literals for 748 /// elements. 749 template <> 750 struct DenseArrayAttrUtil<bool> : public DenseArrayAttrIntUtil<1> { 751 static void printElement(raw_ostream &os, bool value) { 752 os << (value ? "true" : "false"); 753 } 754 }; 755 756 /// Specialization for 8-bit integers to ensure values are printed as integers 757 /// and not characters. 758 template <> 759 struct DenseArrayAttrUtil<int8_t> : public DenseArrayAttrIntUtil<8> { 760 static void printElement(raw_ostream &os, int8_t value) { 761 os << static_cast<int>(value); 762 } 763 }; 764 template <> 765 struct DenseArrayAttrUtil<int16_t> : public DenseArrayAttrIntUtil<16> {}; 766 template <> 767 struct DenseArrayAttrUtil<int32_t> : public DenseArrayAttrIntUtil<32> {}; 768 template <> 769 struct DenseArrayAttrUtil<int64_t> : public DenseArrayAttrIntUtil<64> {}; 770 771 /// Specialization for 32-bit floats. 
772 template <> 773 struct DenseArrayAttrUtil<float> { 774 static bool checkElementType(Type eltType) { return eltType.isF32(); } 775 static Type getElementType(MLIRContext *ctx) { return Float32Type::get(ctx); } 776 static void printElement(raw_ostream &os, float value) { os << value; } 777 778 /// Parse a double and cast it to a float. 779 static ParseResult parseElement(AsmParser &parser, float &value) { 780 double doubleVal; 781 if (parser.parseFloat(doubleVal)) 782 return failure(); 783 value = doubleVal; 784 return success(); 785 } 786 }; 787 788 /// Specialization for 64-bit floats. 789 template <> 790 struct DenseArrayAttrUtil<double> { 791 static bool checkElementType(Type eltType) { return eltType.isF64(); } 792 static Type getElementType(MLIRContext *ctx) { return Float64Type::get(ctx); } 793 static void printElement(raw_ostream &os, float value) { os << value; } 794 static ParseResult parseElement(AsmParser &parser, double &value) { 795 return parser.parseFloat(value); 796 } 797 }; 798 } // namespace 799 800 template <typename T> 801 void DenseArrayAttrImpl<T>::print(AsmPrinter &printer) const { 802 print(printer.getStream()); 803 } 804 805 template <typename T> 806 void DenseArrayAttrImpl<T>::printWithoutBraces(raw_ostream &os) const { 807 llvm::interleaveComma(asArrayRef(), os, [&](T value) { 808 DenseArrayAttrUtil<T>::printElement(os, value); 809 }); 810 } 811 812 template <typename T> 813 void DenseArrayAttrImpl<T>::print(raw_ostream &os) const { 814 os << "["; 815 printWithoutBraces(os); 816 os << "]"; 817 } 818 819 /// Parse a DenseArrayAttr without the braces: `1, 2, 3` 820 template <typename T> 821 Attribute DenseArrayAttrImpl<T>::parseWithoutBraces(AsmParser &parser, 822 Type odsType) { 823 SmallVector<T> data; 824 if (failed(parser.parseCommaSeparatedList([&]() { 825 T value; 826 if (DenseArrayAttrUtil<T>::parseElement(parser, value)) 827 return failure(); 828 data.push_back(value); 829 return success(); 830 }))) 831 return {}; 832 return 
get(parser.getContext(), data); 833 } 834 835 /// Parse a DenseArrayAttr: `[ 1, 2, 3 ]` 836 template <typename T> 837 Attribute DenseArrayAttrImpl<T>::parse(AsmParser &parser, Type odsType) { 838 if (parser.parseLSquare()) 839 return {}; 840 // Handle empty list case. 841 if (succeeded(parser.parseOptionalRSquare())) 842 return get(parser.getContext(), {}); 843 Attribute result = parseWithoutBraces(parser, odsType); 844 if (parser.parseRSquare()) 845 return {}; 846 return result; 847 } 848 849 /// Conversion from DenseArrayAttr<T> to ArrayRef<T>. 850 template <typename T> 851 DenseArrayAttrImpl<T>::operator ArrayRef<T>() const { 852 ArrayRef<char> raw = getRawData(); 853 assert((raw.size() % sizeof(T)) == 0); 854 return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()), 855 raw.size() / sizeof(T)); 856 } 857 858 /// Builds a DenseArrayAttr<T> from an ArrayRef<T>. 859 template <typename T> 860 DenseArrayAttrImpl<T> DenseArrayAttrImpl<T>::get(MLIRContext *context, 861 ArrayRef<T> content) { 862 Type elementType = DenseArrayAttrUtil<T>::getElementType(context); 863 auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()), 864 content.size() * sizeof(T)); 865 return llvm::cast<DenseArrayAttrImpl<T>>( 866 Base::get(context, elementType, content.size(), rawArray)); 867 } 868 869 template <typename T> 870 bool DenseArrayAttrImpl<T>::classof(Attribute attr) { 871 if (auto denseArray = llvm::dyn_cast<DenseArrayAttr>(attr)) 872 return DenseArrayAttrUtil<T>::checkElementType(denseArray.getElementType()); 873 return false; 874 } 875 876 namespace mlir { 877 namespace detail { 878 // Explicit instantiation for all the supported DenseArrayAttr. 
879 template class DenseArrayAttrImpl<bool>; 880 template class DenseArrayAttrImpl<int8_t>; 881 template class DenseArrayAttrImpl<int16_t>; 882 template class DenseArrayAttrImpl<int32_t>; 883 template class DenseArrayAttrImpl<int64_t>; 884 template class DenseArrayAttrImpl<float>; 885 template class DenseArrayAttrImpl<double>; 886 } // namespace detail 887 } // namespace mlir 888 889 //===----------------------------------------------------------------------===// 890 // DenseElementsAttr 891 //===----------------------------------------------------------------------===// 892 893 /// Method for support type inquiry through isa, cast and dyn_cast. 894 bool DenseElementsAttr::classof(Attribute attr) { 895 return llvm::isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(attr); 896 } 897 898 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 899 ArrayRef<Attribute> values) { 900 assert(hasSameElementsOrSplat(type, values)); 901 902 Type eltType = type.getElementType(); 903 904 // Take care complex type case first. 905 if (auto complexType = llvm::dyn_cast<ComplexType>(eltType)) { 906 if (complexType.getElementType().isIntOrIndex()) { 907 SmallVector<std::complex<APInt>> complexValues; 908 complexValues.reserve(values.size()); 909 for (Attribute attr : values) { 910 assert(llvm::isa<ArrayAttr>(attr) && "expected ArrayAttr for complex"); 911 auto arrayAttr = llvm::cast<ArrayAttr>(attr); 912 assert(arrayAttr.size() == 2 && "expected 2 element for complex"); 913 auto attr0 = arrayAttr[0]; 914 auto attr1 = arrayAttr[1]; 915 complexValues.push_back( 916 std::complex<APInt>(llvm::cast<IntegerAttr>(attr0).getValue(), 917 llvm::cast<IntegerAttr>(attr1).getValue())); 918 } 919 return DenseElementsAttr::get(type, complexValues); 920 } 921 // Must be float. 
922 SmallVector<std::complex<APFloat>> complexValues; 923 complexValues.reserve(values.size()); 924 for (Attribute attr : values) { 925 assert(llvm::isa<ArrayAttr>(attr) && "expected ArrayAttr for complex"); 926 auto arrayAttr = llvm::cast<ArrayAttr>(attr); 927 assert(arrayAttr.size() == 2 && "expected 2 element for complex"); 928 auto attr0 = arrayAttr[0]; 929 auto attr1 = arrayAttr[1]; 930 complexValues.push_back( 931 std::complex<APFloat>(llvm::cast<FloatAttr>(attr0).getValue(), 932 llvm::cast<FloatAttr>(attr1).getValue())); 933 } 934 return DenseElementsAttr::get(type, complexValues); 935 } 936 937 // If the element type is not based on int/float/index, assume it is a string 938 // type. 939 if (!eltType.isIntOrIndexOrFloat()) { 940 SmallVector<StringRef, 8> stringValues; 941 stringValues.reserve(values.size()); 942 for (Attribute attr : values) { 943 assert(llvm::isa<StringAttr>(attr) && 944 "expected string value for non integer/index/float element"); 945 stringValues.push_back(llvm::cast<StringAttr>(attr).getValue()); 946 } 947 return get(type, stringValues); 948 } 949 950 // Otherwise, get the raw storage width to use for the allocation. 951 size_t bitWidth = getDenseElementBitWidth(eltType); 952 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 953 954 // Compress the attribute values into a character buffer. 
955 SmallVector<char, 8> data( 956 llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT)); 957 APInt intVal; 958 for (unsigned i = 0, e = values.size(); i < e; ++i) { 959 if (auto floatAttr = llvm::dyn_cast<FloatAttr>(values[i])) { 960 assert(floatAttr.getType() == eltType && 961 "expected float attribute type to equal element type"); 962 intVal = floatAttr.getValue().bitcastToAPInt(); 963 } else { 964 auto intAttr = llvm::cast<IntegerAttr>(values[i]); 965 assert(intAttr.getType() == eltType && 966 "expected integer attribute type to equal element type"); 967 intVal = intAttr.getValue(); 968 } 969 970 assert(intVal.getBitWidth() == bitWidth && 971 "expected value to have same bitwidth as element type"); 972 writeBits(data.data(), i * storageBitWidth, intVal); 973 } 974 975 // Handle the special encoding of splat of bool. 976 if (values.size() == 1 && eltType.isInteger(1)) 977 data[0] = data[0] ? -1 : 0; 978 979 return DenseIntOrFPElementsAttr::getRaw(type, data); 980 } 981 982 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 983 ArrayRef<bool> values) { 984 assert(hasSameElementsOrSplat(type, values)); 985 assert(type.getElementType().isInteger(1)); 986 987 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 988 989 if (!values.empty()) { 990 bool isSplat = true; 991 bool firstValue = values[0]; 992 for (int i = 0, e = values.size(); i != e; ++i) { 993 isSplat &= values[i] == firstValue; 994 setBit(buff.data(), i, values[i]); 995 } 996 997 // Splat of bool is encoded as a byte with all-ones in it. 998 if (isSplat) { 999 buff.resize(1); 1000 buff[0] = values[0] ? 
-1 : 0; 1001 } 1002 } 1003 1004 return DenseIntOrFPElementsAttr::getRaw(type, buff); 1005 } 1006 1007 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 1008 ArrayRef<StringRef> values) { 1009 assert(!type.getElementType().isIntOrFloat()); 1010 return DenseStringElementsAttr::get(type, values); 1011 } 1012 1013 /// Constructs a dense integer elements attribute from an array of APInt 1014 /// values. Each APInt value is expected to have the same bitwidth as the 1015 /// element type of 'type'. 1016 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 1017 ArrayRef<APInt> values) { 1018 assert(type.getElementType().isIntOrIndex()); 1019 assert(hasSameElementsOrSplat(type, values)); 1020 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 1021 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); 1022 } 1023 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 1024 ArrayRef<std::complex<APInt>> values) { 1025 ComplexType complex = llvm::cast<ComplexType>(type.getElementType()); 1026 assert(llvm::isa<IntegerType>(complex.getElementType())); 1027 assert(hasSameElementsOrSplat(type, values)); 1028 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 1029 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()), 1030 values.size() * 2); 1031 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals); 1032 } 1033 1034 // Constructs a dense float elements attribute from an array of APFloat 1035 // values. Each APFloat value is expected to have the same bitwidth as the 1036 // element type of 'type'. 
1037 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 1038 ArrayRef<APFloat> values) { 1039 assert(llvm::isa<FloatType>(type.getElementType())); 1040 assert(hasSameElementsOrSplat(type, values)); 1041 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); 1042 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); 1043 } 1044 DenseElementsAttr 1045 DenseElementsAttr::get(ShapedType type, 1046 ArrayRef<std::complex<APFloat>> values) { 1047 ComplexType complex = llvm::cast<ComplexType>(type.getElementType()); 1048 assert(llvm::isa<FloatType>(complex.getElementType())); 1049 assert(hasSameElementsOrSplat(type, values)); 1050 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), 1051 values.size() * 2); 1052 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; 1053 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals); 1054 } 1055 1056 /// Construct a dense elements attribute from a raw buffer representing the 1057 /// data for this attribute. Users should generally not use this methods as 1058 /// the expected buffer format may not be a form the user expects. 1059 DenseElementsAttr 1060 DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef<char> rawBuffer) { 1061 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer); 1062 } 1063 1064 /// Returns true if the given buffer is a valid raw buffer for the given type. 1065 bool DenseElementsAttr::isValidRawBuffer(ShapedType type, 1066 ArrayRef<char> rawBuffer, 1067 bool &detectedSplat) { 1068 size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); 1069 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; 1070 int64_t numElements = type.getNumElements(); 1071 1072 // The initializer is always a splat if the result type has a single element. 1073 detectedSplat = numElements == 1; 1074 1075 // Storage width of 1 is special as it is packed by the bit. 
1076 if (storageWidth == 1) { 1077 // Check for a splat, or a buffer equal to the number of elements which 1078 // consists of either all 0's or all 1's. 1079 if (rawBuffer.size() == 1) { 1080 auto rawByte = static_cast<uint8_t>(rawBuffer[0]); 1081 if (rawByte == 0 || rawByte == 0xff) { 1082 detectedSplat = true; 1083 return true; 1084 } 1085 } 1086 1087 // This is a valid non-splat buffer if it has the right size. 1088 return rawBufferWidth == llvm::alignTo<8>(numElements); 1089 } 1090 1091 // All other types are 8-bit aligned, so we can just check the buffer width 1092 // to know if only a single initializer element was passed in. 1093 if (rawBufferWidth == storageWidth) { 1094 detectedSplat = true; 1095 return true; 1096 } 1097 1098 // The raw buffer is valid if it has the right size. 1099 return rawBufferWidth == storageWidth * numElements; 1100 } 1101 1102 /// Check the information for a C++ data type, check if this type is valid for 1103 /// the current attribute. This method is used to verify specific type 1104 /// invariants that the templatized 'getValues' method cannot. 1105 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, 1106 bool isSigned) { 1107 // Make sure that the data element size is the same as the type element width. 1108 auto denseEltBitWidth = getDenseElementBitWidth(type); 1109 auto dataSize = static_cast<size_t>(dataEltSize * CHAR_BIT); 1110 if (denseEltBitWidth != dataSize) { 1111 LLVM_DEBUG(llvm::dbgs() << "expected dense element bit width " 1112 << denseEltBitWidth << " to match data size " 1113 << dataSize << " for type " << type << "\n"); 1114 return false; 1115 } 1116 1117 // Check that the element type is either float or integer or index. 
1118 if (!isInt) { 1119 bool valid = llvm::isa<FloatType>(type); 1120 if (!valid) 1121 LLVM_DEBUG(llvm::dbgs() 1122 << "expected float type when isInt is false, but found " 1123 << type << "\n"); 1124 return valid; 1125 } 1126 if (type.isIndex()) 1127 return true; 1128 1129 auto intType = llvm::dyn_cast<IntegerType>(type); 1130 if (!intType) { 1131 LLVM_DEBUG(llvm::dbgs() 1132 << "expected integer type when isInt is true, but found " << type 1133 << "\n"); 1134 return false; 1135 } 1136 1137 // Make sure signedness semantics is consistent. 1138 if (intType.isSignless()) 1139 return true; 1140 1141 bool valid = intType.isSigned() == isSigned; 1142 if (!valid) 1143 LLVM_DEBUG(llvm::dbgs() << "expected signedness " << isSigned 1144 << " to match type " << type << "\n"); 1145 return valid; 1146 } 1147 1148 /// Defaults down the subclass implementation. 1149 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, 1150 ArrayRef<char> data, 1151 int64_t dataEltSize, 1152 bool isInt, bool isSigned) { 1153 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, 1154 isSigned); 1155 } 1156 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 1157 ArrayRef<char> data, 1158 int64_t dataEltSize, 1159 bool isInt, 1160 bool isSigned) { 1161 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, 1162 isInt, isSigned); 1163 } 1164 1165 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, 1166 bool isSigned) const { 1167 return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned); 1168 } 1169 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, 1170 bool isSigned) const { 1171 return ::isValidIntOrFloat( 1172 llvm::cast<ComplexType>(getElementType()).getElementType(), 1173 dataEltSize / 2, isInt, isSigned); 1174 } 1175 1176 /// Returns true if this attribute corresponds to a splat, i.e. if all element 1177 /// values are the same. 
1178 bool DenseElementsAttr::isSplat() const { 1179 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; 1180 } 1181 1182 /// Return if the given complex type has an integer element type. 1183 static bool isComplexOfIntType(Type type) { 1184 return llvm::isa<IntegerType>(llvm::cast<ComplexType>(type).getElementType()); 1185 } 1186 1187 auto DenseElementsAttr::tryGetComplexIntValues() const 1188 -> FailureOr<iterator_range_impl<ComplexIntElementIterator>> { 1189 if (!isComplexOfIntType(getElementType())) 1190 return failure(); 1191 return iterator_range_impl<ComplexIntElementIterator>( 1192 getType(), ComplexIntElementIterator(*this, 0), 1193 ComplexIntElementIterator(*this, getNumElements())); 1194 } 1195 1196 auto DenseElementsAttr::tryGetFloatValues() const 1197 -> FailureOr<iterator_range_impl<FloatElementIterator>> { 1198 auto eltTy = llvm::dyn_cast<FloatType>(getElementType()); 1199 if (!eltTy) 1200 return failure(); 1201 const auto &elementSemantics = eltTy.getFloatSemantics(); 1202 return iterator_range_impl<FloatElementIterator>( 1203 getType(), FloatElementIterator(elementSemantics, raw_int_begin()), 1204 FloatElementIterator(elementSemantics, raw_int_end())); 1205 } 1206 1207 auto DenseElementsAttr::tryGetComplexFloatValues() const 1208 -> FailureOr<iterator_range_impl<ComplexFloatElementIterator>> { 1209 auto complexTy = llvm::dyn_cast<ComplexType>(getElementType()); 1210 if (!complexTy) 1211 return failure(); 1212 auto eltTy = llvm::dyn_cast<FloatType>(complexTy.getElementType()); 1213 if (!eltTy) 1214 return failure(); 1215 const auto &semantics = eltTy.getFloatSemantics(); 1216 return iterator_range_impl<ComplexFloatElementIterator>( 1217 getType(), {semantics, {*this, 0}}, 1218 {semantics, {*this, static_cast<size_t>(getNumElements())}}); 1219 } 1220 1221 /// Return the raw storage data held by this attribute. 
1222 ArrayRef<char> DenseElementsAttr::getRawData() const { 1223 return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data; 1224 } 1225 1226 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { 1227 return static_cast<DenseStringElementsAttrStorage *>(impl)->data; 1228 } 1229 1230 /// Return a new DenseElementsAttr that has the same data as the current 1231 /// attribute, but has been reshaped to 'newType'. The new type must have the 1232 /// same total number of elements as well as element type. 1233 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 1234 ShapedType curType = getType(); 1235 if (curType == newType) 1236 return *this; 1237 1238 assert(newType.getElementType() == curType.getElementType() && 1239 "expected the same element type"); 1240 assert(newType.getNumElements() == curType.getNumElements() && 1241 "expected the same number of elements"); 1242 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData()); 1243 } 1244 1245 DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) { 1246 assert(isSplat() && "expected a splat type"); 1247 1248 ShapedType curType = getType(); 1249 if (curType == newType) 1250 return *this; 1251 1252 assert(newType.getElementType() == curType.getElementType() && 1253 "expected the same element type"); 1254 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData()); 1255 } 1256 1257 /// Return a new DenseElementsAttr that has the same data as the current 1258 /// attribute, but has bitcast elements such that it is now 'newType'. The new 1259 /// type must have the same shape and element types of the same bitwidth as the 1260 /// current type. 
1261 DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) { 1262 ShapedType curType = getType(); 1263 Type curElType = curType.getElementType(); 1264 if (curElType == newElType) 1265 return *this; 1266 1267 assert(getDenseElementBitWidth(newElType) == 1268 getDenseElementBitWidth(curElType) && 1269 "expected element types with the same bitwidth"); 1270 return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType), 1271 getRawData()); 1272 } 1273 1274 DenseElementsAttr 1275 DenseElementsAttr::mapValues(Type newElementType, 1276 function_ref<APInt(const APInt &)> mapping) const { 1277 return llvm::cast<DenseIntElementsAttr>(*this).mapValues(newElementType, 1278 mapping); 1279 } 1280 1281 DenseElementsAttr DenseElementsAttr::mapValues( 1282 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1283 return llvm::cast<DenseFPElementsAttr>(*this).mapValues(newElementType, 1284 mapping); 1285 } 1286 1287 ShapedType DenseElementsAttr::getType() const { 1288 return static_cast<const DenseElementsAttributeStorage *>(impl)->type; 1289 } 1290 1291 Type DenseElementsAttr::getElementType() const { 1292 return getType().getElementType(); 1293 } 1294 1295 int64_t DenseElementsAttr::getNumElements() const { 1296 return getType().getNumElements(); 1297 } 1298 1299 //===----------------------------------------------------------------------===// 1300 // DenseIntOrFPElementsAttr 1301 //===----------------------------------------------------------------------===// 1302 1303 /// Utility method to write a range of APInt values to a buffer. 
1304 template <typename APRangeT> 1305 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, 1306 APRangeT &&values) { 1307 size_t numValues = llvm::size(values); 1308 data.resize(llvm::divideCeil(storageWidth * numValues, CHAR_BIT)); 1309 size_t offset = 0; 1310 for (auto it = values.begin(), e = values.end(); it != e; 1311 ++it, offset += storageWidth) { 1312 assert((*it).getBitWidth() <= storageWidth); 1313 writeBits(data.data(), offset, *it); 1314 } 1315 1316 // Handle the special encoding of splat of a boolean. 1317 if (numValues == 1 && (*values.begin()).getBitWidth() == 1) 1318 data[0] = data[0] ? -1 : 0; 1319 } 1320 1321 /// Constructs a dense elements attribute from an array of raw APFloat values. 1322 /// Each APFloat value is expected to have the same bitwidth as the element 1323 /// type of 'type'. 'type' must be a vector or tensor with static shape. 1324 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1325 size_t storageWidth, 1326 ArrayRef<APFloat> values) { 1327 std::vector<char> data; 1328 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; 1329 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); 1330 return DenseIntOrFPElementsAttr::getRaw(type, data); 1331 } 1332 1333 /// Constructs a dense elements attribute from an array of raw APInt values. 1334 /// Each APInt value is expected to have the same bitwidth as the element type 1335 /// of 'type'. 
1336 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1337 size_t storageWidth, 1338 ArrayRef<APInt> values) { 1339 std::vector<char> data; 1340 writeAPIntsToBuffer(storageWidth, data, values); 1341 return DenseIntOrFPElementsAttr::getRaw(type, data); 1342 } 1343 1344 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, 1345 ArrayRef<char> data) { 1346 assert(type.hasStaticShape() && "type must have static shape"); 1347 bool isSplat = false; 1348 bool isValid = isValidRawBuffer(type, data, isSplat); 1349 assert(isValid); 1350 (void)isValid; 1351 return Base::get(type.getContext(), type, data, isSplat); 1352 } 1353 1354 /// Overload of the raw 'get' method that asserts that the given type is of 1355 /// complex type. This method is used to verify type invariants that the 1356 /// templatized 'get' method cannot. 1357 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, 1358 ArrayRef<char> data, 1359 int64_t dataEltSize, 1360 bool isInt, 1361 bool isSigned) { 1362 assert(::isValidIntOrFloat( 1363 llvm::cast<ComplexType>(type.getElementType()).getElementType(), 1364 dataEltSize / 2, isInt, isSigned) && 1365 "Try re-running with -debug-only=builtinattributes"); 1366 1367 int64_t numElements = data.size() / dataEltSize; 1368 (void)numElements; 1369 assert(numElements == 1 || numElements == type.getNumElements()); 1370 return getRaw(type, data); 1371 } 1372 1373 /// Overload of the 'getRaw' method that asserts that the given type is of 1374 /// integer type. This method is used to verify type invariants that the 1375 /// templatized 'get' method cannot. 
1376 DenseElementsAttr 1377 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, 1378 int64_t dataEltSize, bool isInt, 1379 bool isSigned) { 1380 assert(::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, 1381 isSigned) && 1382 "Try re-running with -debug-only=builtinattributes"); 1383 1384 int64_t numElements = data.size() / dataEltSize; 1385 assert(numElements == 1 || numElements == type.getNumElements()); 1386 (void)numElements; 1387 return getRaw(type, data); 1388 } 1389 1390 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( 1391 const char *inRawData, char *outRawData, size_t elementBitWidth, 1392 size_t numElements) { 1393 using llvm::support::ulittle16_t; 1394 using llvm::support::ulittle32_t; 1395 using llvm::support::ulittle64_t; 1396 1397 assert(llvm::endianness::native == llvm::endianness::big); 1398 // NOLINT to avoid warning message about replacing by static_assert() 1399 1400 // Following std::copy_n always converts endianness on BE machine. 
1401 switch (elementBitWidth) { 1402 case 16: { 1403 const ulittle16_t *inRawDataPos = 1404 reinterpret_cast<const ulittle16_t *>(inRawData); 1405 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData); 1406 std::copy_n(inRawDataPos, numElements, outDataPos); 1407 break; 1408 } 1409 case 32: { 1410 const ulittle32_t *inRawDataPos = 1411 reinterpret_cast<const ulittle32_t *>(inRawData); 1412 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData); 1413 std::copy_n(inRawDataPos, numElements, outDataPos); 1414 break; 1415 } 1416 case 64: { 1417 const ulittle64_t *inRawDataPos = 1418 reinterpret_cast<const ulittle64_t *>(inRawData); 1419 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData); 1420 std::copy_n(inRawDataPos, numElements, outDataPos); 1421 break; 1422 } 1423 default: { 1424 size_t nBytes = elementBitWidth / CHAR_BIT; 1425 for (size_t i = 0; i < nBytes; i++) 1426 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i); 1427 break; 1428 } 1429 } 1430 } 1431 1432 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( 1433 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData, 1434 ShapedType type) { 1435 size_t numElements = type.getNumElements(); 1436 Type elementType = type.getElementType(); 1437 if (ComplexType complexTy = llvm::dyn_cast<ComplexType>(elementType)) { 1438 elementType = complexTy.getElementType(); 1439 numElements = numElements * 2; 1440 } 1441 size_t elementBitWidth = getDenseElementStorageWidth(elementType); 1442 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT && 1443 inRawData.size() <= outRawData.size()); 1444 if (elementBitWidth <= CHAR_BIT) 1445 std::memcpy(outRawData.begin(), inRawData.begin(), inRawData.size()); 1446 else 1447 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(), 1448 elementBitWidth, numElements); 1449 } 1450 1451 //===----------------------------------------------------------------------===// 1452 // DenseFPElementsAttr 1453 
//===----------------------------------------------------------------------===// 1454 1455 template <typename Fn, typename Attr> 1456 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 1457 Type newElementType, 1458 llvm::SmallVectorImpl<char> &data) { 1459 size_t bitWidth = getDenseElementBitWidth(newElementType); 1460 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 1461 1462 ShapedType newArrayType = inType.cloneWith(inType.getShape(), newElementType); 1463 1464 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 1465 data.resize(llvm::divideCeil(storageBitWidth * numRawElements, CHAR_BIT)); 1466 1467 // Functor used to process a single element value of the attribute. 1468 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 1469 auto newInt = mapping(value); 1470 assert(newInt.getBitWidth() == bitWidth); 1471 writeBits(data.data(), index * storageBitWidth, newInt); 1472 }; 1473 1474 // Check for the splat case. 1475 if (attr.isSplat()) { 1476 if (bitWidth == 1) { 1477 // Handle the special encoding of splat of bool. 1478 data[0] = mapping(*attr.begin()).isZero() ? 0 : -1; 1479 } else { 1480 processElt(*attr.begin(), /*index=*/0); 1481 } 1482 return newArrayType; 1483 } 1484 1485 // Otherwise, process all of the element values. 1486 uint64_t elementIdx = 0; 1487 for (auto value : attr) 1488 processElt(value, elementIdx++); 1489 return newArrayType; 1490 } 1491 1492 DenseElementsAttr DenseFPElementsAttr::mapValues( 1493 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { 1494 llvm::SmallVector<char, 8> elementData; 1495 auto newArrayType = 1496 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1497 1498 return getRaw(newArrayType, elementData); 1499 } 1500 1501 /// Method for supporting type inquiry through isa, cast and dyn_cast. 
1502 bool DenseFPElementsAttr::classof(Attribute attr) { 1503 if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(attr)) 1504 return llvm::isa<FloatType>(denseAttr.getType().getElementType()); 1505 return false; 1506 } 1507 1508 //===----------------------------------------------------------------------===// 1509 // DenseIntElementsAttr 1510 //===----------------------------------------------------------------------===// 1511 1512 DenseElementsAttr DenseIntElementsAttr::mapValues( 1513 Type newElementType, function_ref<APInt(const APInt &)> mapping) const { 1514 llvm::SmallVector<char, 8> elementData; 1515 auto newArrayType = 1516 mappingHelper(mapping, *this, getType(), newElementType, elementData); 1517 return getRaw(newArrayType, elementData); 1518 } 1519 1520 /// Method for supporting type inquiry through isa, cast and dyn_cast. 1521 bool DenseIntElementsAttr::classof(Attribute attr) { 1522 if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(attr)) 1523 return denseAttr.getType().getElementType().isIntOrIndex(); 1524 return false; 1525 } 1526 1527 //===----------------------------------------------------------------------===// 1528 // DenseResourceElementsAttr 1529 //===----------------------------------------------------------------------===// 1530 1531 DenseResourceElementsAttr 1532 DenseResourceElementsAttr::get(ShapedType type, 1533 DenseResourceElementsHandle handle) { 1534 return Base::get(type.getContext(), type, handle); 1535 } 1536 1537 DenseResourceElementsAttr DenseResourceElementsAttr::get(ShapedType type, 1538 StringRef blobName, 1539 AsmResourceBlob blob) { 1540 // Extract the builtin dialect resource manager from context and construct a 1541 // handle by inserting a new resource using the provided blob. 
1542 auto &manager = 1543 DenseResourceElementsHandle::getManagerInterface(type.getContext()); 1544 return get(type, manager.insert(blobName, std::move(blob))); 1545 } 1546 1547 ArrayRef<char> DenseResourceElementsAttr::getData() { 1548 if (AsmResourceBlob *blob = this->getRawHandle().getBlob()) 1549 return blob->getDataAs<char>(); 1550 return {}; 1551 } 1552 1553 //===----------------------------------------------------------------------===// 1554 // DenseResourceElementsAttrBase 1555 1556 namespace { 1557 /// Instantiations of this class provide utilities for interacting with native 1558 /// data types in the context of DenseResourceElementsAttr. 1559 template <typename T> 1560 struct DenseResourceAttrUtil; 1561 template <size_t width, bool isSigned> 1562 struct DenseResourceElementsAttrIntUtil { 1563 static bool checkElementType(Type eltType) { 1564 IntegerType type = llvm::dyn_cast<IntegerType>(eltType); 1565 if (!type || type.getWidth() != width) 1566 return false; 1567 return isSigned ? 
!type.isUnsigned() : !type.isSigned(); 1568 } 1569 }; 1570 template <> 1571 struct DenseResourceAttrUtil<bool> { 1572 static bool checkElementType(Type eltType) { 1573 return eltType.isSignlessInteger(1); 1574 } 1575 }; 1576 template <> 1577 struct DenseResourceAttrUtil<int8_t> 1578 : public DenseResourceElementsAttrIntUtil<8, true> {}; 1579 template <> 1580 struct DenseResourceAttrUtil<uint8_t> 1581 : public DenseResourceElementsAttrIntUtil<8, false> {}; 1582 template <> 1583 struct DenseResourceAttrUtil<int16_t> 1584 : public DenseResourceElementsAttrIntUtil<16, true> {}; 1585 template <> 1586 struct DenseResourceAttrUtil<uint16_t> 1587 : public DenseResourceElementsAttrIntUtil<16, false> {}; 1588 template <> 1589 struct DenseResourceAttrUtil<int32_t> 1590 : public DenseResourceElementsAttrIntUtil<32, true> {}; 1591 template <> 1592 struct DenseResourceAttrUtil<uint32_t> 1593 : public DenseResourceElementsAttrIntUtil<32, false> {}; 1594 template <> 1595 struct DenseResourceAttrUtil<int64_t> 1596 : public DenseResourceElementsAttrIntUtil<64, true> {}; 1597 template <> 1598 struct DenseResourceAttrUtil<uint64_t> 1599 : public DenseResourceElementsAttrIntUtil<64, false> {}; 1600 template <> 1601 struct DenseResourceAttrUtil<float> { 1602 static bool checkElementType(Type eltType) { return eltType.isF32(); } 1603 }; 1604 template <> 1605 struct DenseResourceAttrUtil<double> { 1606 static bool checkElementType(Type eltType) { return eltType.isF64(); } 1607 }; 1608 } // namespace 1609 1610 template <typename T> 1611 DenseResourceElementsAttrBase<T> 1612 DenseResourceElementsAttrBase<T>::get(ShapedType type, StringRef blobName, 1613 AsmResourceBlob blob) { 1614 // Check that the blob is in the form we were expecting. 
1615 assert(blob.getDataAlignment() == alignof(T) && 1616 "alignment mismatch between expected alignment and blob alignment"); 1617 assert(((blob.getData().size() % sizeof(T)) == 0) && 1618 "size mismatch between expected element width and blob size"); 1619 assert(DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) && 1620 "invalid shape element type for provided type `T`"); 1621 return llvm::cast<DenseResourceElementsAttrBase<T>>( 1622 DenseResourceElementsAttr::get(type, blobName, std::move(blob))); 1623 } 1624 1625 template <typename T> 1626 std::optional<ArrayRef<T>> 1627 DenseResourceElementsAttrBase<T>::tryGetAsArrayRef() const { 1628 if (AsmResourceBlob *blob = this->getRawHandle().getBlob()) 1629 return blob->template getDataAs<T>(); 1630 return std::nullopt; 1631 } 1632 1633 template <typename T> 1634 bool DenseResourceElementsAttrBase<T>::classof(Attribute attr) { 1635 auto resourceAttr = llvm::dyn_cast<DenseResourceElementsAttr>(attr); 1636 return resourceAttr && DenseResourceAttrUtil<T>::checkElementType( 1637 resourceAttr.getElementType()); 1638 } 1639 1640 namespace mlir { 1641 namespace detail { 1642 // Explicit instantiation for all the supported DenseResourceElementsAttr. 
1643 template class DenseResourceElementsAttrBase<bool>; 1644 template class DenseResourceElementsAttrBase<int8_t>; 1645 template class DenseResourceElementsAttrBase<int16_t>; 1646 template class DenseResourceElementsAttrBase<int32_t>; 1647 template class DenseResourceElementsAttrBase<int64_t>; 1648 template class DenseResourceElementsAttrBase<uint8_t>; 1649 template class DenseResourceElementsAttrBase<uint16_t>; 1650 template class DenseResourceElementsAttrBase<uint32_t>; 1651 template class DenseResourceElementsAttrBase<uint64_t>; 1652 template class DenseResourceElementsAttrBase<float>; 1653 template class DenseResourceElementsAttrBase<double>; 1654 } // namespace detail 1655 } // namespace mlir 1656 1657 //===----------------------------------------------------------------------===// 1658 // SparseElementsAttr 1659 //===----------------------------------------------------------------------===// 1660 1661 /// Get a zero APFloat for the given sparse attribute. 1662 APFloat SparseElementsAttr::getZeroAPFloat() const { 1663 auto eltType = llvm::cast<FloatType>(getElementType()); 1664 return APFloat(eltType.getFloatSemantics()); 1665 } 1666 1667 /// Get a zero APInt for the given sparse attribute. 1668 APInt SparseElementsAttr::getZeroAPInt() const { 1669 auto eltType = llvm::cast<IntegerType>(getElementType()); 1670 return APInt::getZero(eltType.getWidth()); 1671 } 1672 1673 /// Get a zero attribute for the given attribute type. 1674 Attribute SparseElementsAttr::getZeroAttr() const { 1675 auto eltType = getElementType(); 1676 1677 // Handle floating point elements. 1678 if (llvm::isa<FloatType>(eltType)) 1679 return FloatAttr::get(eltType, 0); 1680 1681 // Handle complex elements. 
1682 if (auto complexTy = llvm::dyn_cast<ComplexType>(eltType)) { 1683 auto eltType = complexTy.getElementType(); 1684 Attribute zero; 1685 if (llvm::isa<FloatType>(eltType)) 1686 zero = FloatAttr::get(eltType, 0); 1687 else // must be integer 1688 zero = IntegerAttr::get(eltType, 0); 1689 return ArrayAttr::get(complexTy.getContext(), 1690 ArrayRef<Attribute>{zero, zero}); 1691 } 1692 1693 // Handle string type. 1694 if (llvm::isa<DenseStringElementsAttr>(getValues())) 1695 return StringAttr::get("", eltType); 1696 1697 // Otherwise, this is an integer. 1698 return IntegerAttr::get(eltType, 0); 1699 } 1700 1701 /// Flatten, and return, all of the sparse indices in this attribute in 1702 /// row-major order. 1703 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1704 std::vector<ptrdiff_t> flatSparseIndices; 1705 1706 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1707 // as a 1-D index array. 1708 auto sparseIndices = getIndices(); 1709 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1710 if (sparseIndices.isSplat()) { 1711 SmallVector<uint64_t, 8> indices(getType().getRank(), 1712 *sparseIndexValues.begin()); 1713 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1714 return flatSparseIndices; 1715 } 1716 1717 // Otherwise, reinterpret each index as an ArrayRef when flattening. 
1718 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1719 size_t rank = getType().getRank(); 1720 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1721 flatSparseIndices.push_back(getFlattenedIndex( 1722 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1723 return flatSparseIndices; 1724 } 1725 1726 LogicalResult 1727 SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1728 ShapedType type, DenseIntElementsAttr sparseIndices, 1729 DenseElementsAttr values) { 1730 ShapedType valuesType = values.getType(); 1731 if (valuesType.getRank() != 1) 1732 return emitError() << "expected 1-d tensor for sparse element values"; 1733 1734 // Verify the indices and values shape. 1735 ShapedType indicesType = sparseIndices.getType(); 1736 auto emitShapeError = [&]() { 1737 return emitError() << "expected shape ([" << type.getShape() 1738 << "]); inferred shape of indices literal ([" 1739 << indicesType.getShape() 1740 << "]); inferred shape of values literal ([" 1741 << valuesType.getShape() << "])"; 1742 }; 1743 // Verify indices shape. 1744 size_t rank = type.getRank(), indicesRank = indicesType.getRank(); 1745 if (indicesRank == 2) { 1746 if (indicesType.getDimSize(1) != static_cast<int64_t>(rank)) 1747 return emitShapeError(); 1748 } else if (indicesRank != 1 || rank != 1) { 1749 return emitShapeError(); 1750 } 1751 // Verify the values shape. 1752 int64_t numSparseIndices = indicesType.getDimSize(0); 1753 if (numSparseIndices != valuesType.getDimSize(0)) 1754 return emitShapeError(); 1755 1756 // Verify that the sparse indices are within the value shape. 1757 auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) { 1758 return emitError() 1759 << "sparse index #" << indexNum 1760 << " is not contained within the value shape, with index=[" << index 1761 << "], and type=" << type; 1762 }; 1763 1764 // Handle the case where the index values are a splat. 
1765 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1766 if (sparseIndices.isSplat()) { 1767 SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin()); 1768 if (!ElementsAttr::isValidIndex(type, indices)) 1769 return emitIndexError(0, indices); 1770 return success(); 1771 } 1772 1773 // Otherwise, reinterpret each index as an ArrayRef. 1774 for (size_t i = 0, e = numSparseIndices; i != e; ++i) { 1775 ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank), 1776 rank); 1777 if (!ElementsAttr::isValidIndex(type, index)) 1778 return emitIndexError(i, index); 1779 } 1780 1781 return success(); 1782 } 1783 1784 //===----------------------------------------------------------------------===// 1785 // DistinctAttr 1786 //===----------------------------------------------------------------------===// 1787 1788 DistinctAttr DistinctAttr::create(Attribute referencedAttr) { 1789 return Base::get(referencedAttr.getContext(), referencedAttr); 1790 } 1791 1792 Attribute DistinctAttr::getReferencedAttr() const { 1793 return getImpl()->referencedAttr; 1794 } 1795 1796 //===----------------------------------------------------------------------===// 1797 // Attribute Utilities 1798 //===----------------------------------------------------------------------===// 1799 1800 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, 1801 int64_t offset, 1802 MLIRContext *context) { 1803 AffineExpr expr; 1804 unsigned nSymbols = 0; 1805 1806 // AffineExpr for offset. 1807 // Static case. 1808 if (!ShapedType::isDynamic(offset)) { 1809 auto cst = getAffineConstantExpr(offset, context); 1810 expr = cst; 1811 } else { 1812 // Dynamic case, new symbol for the offset. 1813 auto sym = getAffineSymbolExpr(nSymbols++, context); 1814 expr = sym; 1815 } 1816 1817 // AffineExpr for strides. 
1818 for (const auto &en : llvm::enumerate(strides)) { 1819 auto dim = en.index(); 1820 auto stride = en.value(); 1821 auto d = getAffineDimExpr(dim, context); 1822 AffineExpr mult; 1823 // Static case. 1824 if (!ShapedType::isDynamic(stride)) 1825 mult = getAffineConstantExpr(stride, context); 1826 else 1827 // Dynamic case, new symbol for each new stride. 1828 mult = getAffineSymbolExpr(nSymbols++, context); 1829 expr = expr + d * mult; 1830 } 1831 1832 return AffineMap::get(strides.size(), nSymbols, expr); 1833 } 1834