1 //===- BuiltinAttributes.h - MLIR Builtin Attribute Classes -----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_IR_BUILTINATTRIBUTES_H 10 #define MLIR_IR_BUILTINATTRIBUTES_H 11 12 #include "mlir/IR/BuiltinAttributeInterfaces.h" 13 #include "llvm/ADT/APFloat.h" 14 #include "llvm/ADT/Sequence.h" 15 #include <complex> 16 #include <optional> 17 18 namespace mlir { 19 class AffineMap; 20 class AsmResourceBlob; 21 class BoolAttr; 22 class BuiltinDialect; 23 class DenseIntElementsAttr; 24 template <typename T> 25 struct DialectResourceBlobHandle; 26 class FlatSymbolRefAttr; 27 class FunctionType; 28 class IntegerSet; 29 class IntegerType; 30 class Location; 31 class Operation; 32 class RankedTensorType; 33 34 namespace detail { 35 struct DenseIntOrFPElementsAttrStorage; 36 struct DenseStringElementsAttrStorage; 37 struct StringAttrStorage; 38 } // namespace detail 39 40 //===----------------------------------------------------------------------===// 41 // Elements Attributes 42 //===----------------------------------------------------------------------===// 43 44 namespace detail { 45 /// Pair of raw pointer and a boolean flag of whether the pointer holds a splat, 46 using DenseIterPtrAndSplat = std::pair<const char *, bool>; 47 48 /// Impl iterator for indexed DenseElementsAttr iterators that records a data 49 /// pointer and data index that is adjusted for the case of a splat attribute. 50 template <typename ConcreteT, typename T, typename PointerT = T *, 51 typename ReferenceT = T &> 52 class DenseElementIndexedIteratorImpl 53 : public llvm::indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T, 54 PointerT, ReferenceT> { 55 protected: DenseElementIndexedIteratorImpl(const char * data,bool isSplat,size_t dataIndex)56 DenseElementIndexedIteratorImpl(const char *data, bool isSplat, 57 size_t dataIndex) 58 : llvm::indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T, 59 PointerT, ReferenceT>({data, isSplat}, 60 dataIndex) {} 61 62 /// Return the current index for this iterator, adjusted for the case of a 63 /// splat. getDataIndex()64 ptrdiff_t getDataIndex() const { 65 bool isSplat = this->base.second; 66 return isSplat ? 0 : this->index; 67 } 68 69 /// Return the data base pointer. getData()70 const char *getData() const { return this->base.first; } 71 }; 72 73 /// Type trait detector that checks if a given type T is a complex type. 74 template <typename T> 75 struct is_complex_t : public std::false_type {}; 76 template <typename T> 77 struct is_complex_t<std::complex<T>> : public std::true_type {}; 78 } // namespace detail 79 80 /// An attribute that represents a reference to a dense vector or tensor 81 /// object. 82 class DenseElementsAttr : public Attribute { 83 public: 84 using Attribute::Attribute; 85 86 /// Allow implicit conversion to ElementsAttr. 87 operator ElementsAttr() const { return cast_if_present<ElementsAttr>(*this); } 88 /// Allow implicit conversion to TypedAttr. 89 operator TypedAttr() const { return ElementsAttr(*this); } 90 91 /// Type trait used to check if the given type T is a potentially valid C++ 92 /// floating point type that can be used to access the underlying element 93 /// types of a DenseElementsAttr. 94 template <typename T> 95 struct is_valid_cpp_fp_type { 96 /// The type is a valid floating point type if it is a builtin floating 97 /// point type, or is a potentially user defined floating point type. The 98 /// latter allows for supporting users that have custom types defined for 99 /// bfloat16/half/etc. 100 static constexpr bool value = llvm::is_one_of<T, float, double>::value || 101 (std::numeric_limits<T>::is_specialized && 102 !std::numeric_limits<T>::is_integer); 103 }; 104 105 /// Method for support type inquiry through isa, cast and dyn_cast. 106 static bool classof(Attribute attr); 107 108 /// Constructs a dense elements attribute from an array of element values. 109 /// Each element attribute value is expected to be an element of 'type'. 110 /// 'type' must be a vector or tensor with static shape. If the element of 111 /// `type` is non-integer/index/float it is assumed to be a string type. 112 static DenseElementsAttr get(ShapedType type, ArrayRef<Attribute> values); 113 114 /// Constructs a dense integer elements attribute from an array of integer 115 /// or floating-point values. Each value is expected to be the same bitwidth 116 /// of the element type of 'type'. 'type' must be a vector or tensor with 117 /// static shape. 118 template <typename T, 119 typename = std::enable_if_t<std::numeric_limits<T>::is_integer || 120 is_valid_cpp_fp_type<T>::value>> 121 static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) { 122 const char *data = reinterpret_cast<const char *>(values.data()); 123 return getRawIntOrFloat( 124 type, ArrayRef<char>(data, values.size() * sizeof(T)), sizeof(T), 125 std::numeric_limits<T>::is_integer, std::numeric_limits<T>::is_signed); 126 } 127 128 /// Constructs a dense integer elements attribute from a single element. 129 template <typename T, 130 typename = std::enable_if_t<std::numeric_limits<T>::is_integer || 131 is_valid_cpp_fp_type<T>::value || 132 detail::is_complex_t<T>::value>> 133 static DenseElementsAttr get(const ShapedType &type, T value) { 134 return get(type, llvm::ArrayRef(value)); 135 } 136 137 /// Constructs a dense complex elements attribute from an array of complex 138 /// values. Each value is expected to be the same bitwidth of the element type 139 /// of 'type'. 'type' must be a vector or tensor with static shape. 140 template < 141 typename T, typename ElementT = typename T::value_type, 142 typename = std::enable_if_t<detail::is_complex_t<T>::value && 143 (std::numeric_limits<ElementT>::is_integer || 144 is_valid_cpp_fp_type<ElementT>::value)>> 145 static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) { 146 const char *data = reinterpret_cast<const char *>(values.data()); 147 return getRawComplex(type, ArrayRef<char>(data, values.size() * sizeof(T)), 148 sizeof(T), std::numeric_limits<ElementT>::is_integer, 149 std::numeric_limits<ElementT>::is_signed); 150 } 151 152 /// Overload of the above 'get' method that is specialized for boolean values. 153 static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values); 154 155 /// Overload of the above 'get' method that is specialized for StringRef 156 /// values. 157 static DenseElementsAttr get(ShapedType type, ArrayRef<StringRef> values); 158 159 /// Constructs a dense integer elements attribute from an array of APInt 160 /// values. Each APInt value is expected to have the same bitwidth as the 161 /// element type of 'type'. 'type' must be a vector or tensor with static 162 /// shape. 163 static DenseElementsAttr get(ShapedType type, ArrayRef<APInt> values); 164 165 /// Constructs a dense complex elements attribute from an array of APInt 166 /// values. Each APInt value is expected to have the same bitwidth as the 167 /// element type of 'type'. 'type' must be a vector or tensor with static 168 /// shape. 169 static DenseElementsAttr get(ShapedType type, 170 ArrayRef<std::complex<APInt>> values); 171 172 /// Constructs a dense float elements attribute from an array of APFloat 173 /// values. Each APFloat value is expected to have the same bitwidth as the 174 /// element type of 'type'. 'type' must be a vector or tensor with static 175 /// shape. 176 static DenseElementsAttr get(ShapedType type, ArrayRef<APFloat> values); 177 178 /// Constructs a dense complex elements attribute from an array of APFloat 179 /// values. Each APFloat value is expected to have the same bitwidth as the 180 /// element type of 'type'. 'type' must be a vector or tensor with static 181 /// shape. 182 static DenseElementsAttr get(ShapedType type, 183 ArrayRef<std::complex<APFloat>> values); 184 185 /// Construct a dense elements attribute for an initializer_list of values. 186 /// Each value is expected to be the same bitwidth of the element type of 187 /// 'type'. 'type' must be a vector or tensor with static shape. 188 template <typename T> 189 static DenseElementsAttr get(const ShapedType &type, 190 const std::initializer_list<T> &list) { 191 return get(type, ArrayRef<T>(list)); 192 } 193 194 /// Construct a dense elements attribute from a raw buffer representing the 195 /// data for this attribute. Users are encouraged to use one of the 196 /// constructors above, which provide more safeties. However, this 197 /// constructor is useful for tools which may want to interop and can 198 /// follow the precise definition. 199 /// 200 /// The format of the raw buffer is a densely packed array of values that 201 /// can be bitcast to the storage format of the element type specified. 202 /// Types that are not byte aligned will be: 203 /// - For bitwidth > 1: Rounded up to the next byte. 204 /// - For bitwidth = 1: Packed into 8bit bytes with bits corresponding to 205 /// the linear order of the shape type from MSB to LSB, padded to on the 206 /// right. 207 static DenseElementsAttr getFromRawBuffer(ShapedType type, 208 ArrayRef<char> rawBuffer); 209 210 /// Returns true if the given buffer is a valid raw buffer for the given type. 211 /// `detectedSplat` is set if the buffer is valid and represents a splat 212 /// buffer. The definition may be expanded over time, but currently, a 213 /// splat buffer is detected if: 214 /// - For >1bit: The buffer consists of a single element. 215 /// - For 1bit: The buffer consists of a single byte with value 0 or 255. 216 /// 217 /// User code should be prepared for additional, conformant patterns to be 218 /// identified as splats in the future. 219 static bool isValidRawBuffer(ShapedType type, ArrayRef<char> rawBuffer, 220 bool &detectedSplat); 221 222 //===--------------------------------------------------------------------===// 223 // Iterators 224 //===--------------------------------------------------------------------===// 225 226 /// The iterator range over the given iterator type T. 227 template <typename IteratorT> 228 using iterator_range_impl = detail::ElementsAttrRange<IteratorT>; 229 230 /// The iterator for the given element type T. 231 template <typename T, typename AttrT = DenseElementsAttr> 232 using iterator = decltype(std::declval<AttrT>().template value_begin<T>()); 233 /// The iterator range over the given element T. 234 template <typename T, typename AttrT = DenseElementsAttr> 235 using iterator_range = 236 decltype(std::declval<AttrT>().template getValues<T>()); 237 238 /// A utility iterator that allows walking over the internal Attribute values 239 /// of a DenseElementsAttr. 240 class AttributeElementIterator 241 : public llvm::indexed_accessor_iterator<AttributeElementIterator, 242 const void *, Attribute, 243 Attribute, Attribute> { 244 public: 245 /// Accesses the Attribute value at this iterator position. 246 Attribute operator*() const; 247 248 private: 249 friend DenseElementsAttr; 250 251 /// Constructs a new iterator. 252 AttributeElementIterator(DenseElementsAttr attr, size_t index); 253 }; 254 255 /// Iterator for walking raw element values of the specified type 'T', which 256 /// may be any c++ data type matching the stored representation: int32_t, 257 /// float, etc. 258 template <typename T> 259 class ElementIterator 260 : public detail::DenseElementIndexedIteratorImpl<ElementIterator<T>, 261 const T> { 262 public: 263 /// Accesses the raw value at this iterator position. 264 const T &operator*() const { 265 return reinterpret_cast<const T *>(this->getData())[this->getDataIndex()]; 266 } 267 268 private: 269 friend DenseElementsAttr; 270 271 /// Constructs a new iterator. 272 ElementIterator(const char *data, bool isSplat, size_t dataIndex) 273 : detail::DenseElementIndexedIteratorImpl<ElementIterator<T>, const T>( 274 data, isSplat, dataIndex) {} 275 }; 276 277 /// A utility iterator that allows walking over the internal bool values. 278 class BoolElementIterator 279 : public detail::DenseElementIndexedIteratorImpl<BoolElementIterator, 280 bool, bool, bool> { 281 public: 282 /// Accesses the bool value at this iterator position. 283 bool operator*() const; 284 285 private: 286 friend DenseElementsAttr; 287 288 /// Constructs a new iterator. 289 BoolElementIterator(DenseElementsAttr attr, size_t dataIndex); 290 }; 291 292 /// A utility iterator that allows walking over the internal raw APInt values. 293 class IntElementIterator 294 : public detail::DenseElementIndexedIteratorImpl<IntElementIterator, 295 APInt, APInt, APInt> { 296 public: 297 /// Accesses the raw APInt value at this iterator position. 298 APInt operator*() const; 299 300 private: 301 friend DenseElementsAttr; 302 303 /// Constructs a new iterator. 304 IntElementIterator(DenseElementsAttr attr, size_t dataIndex); 305 306 /// The bitwidth of the element type. 307 size_t bitWidth; 308 }; 309 310 /// A utility iterator that allows walking over the internal raw complex APInt 311 /// values. 312 class ComplexIntElementIterator 313 : public detail::DenseElementIndexedIteratorImpl< 314 ComplexIntElementIterator, std::complex<APInt>, std::complex<APInt>, 315 std::complex<APInt>> { 316 public: 317 /// Accesses the raw std::complex<APInt> value at this iterator position. 318 std::complex<APInt> operator*() const; 319 320 private: 321 friend DenseElementsAttr; 322 323 /// Constructs a new iterator. 324 ComplexIntElementIterator(DenseElementsAttr attr, size_t dataIndex); 325 326 /// The bitwidth of the element type. 327 size_t bitWidth; 328 }; 329 330 /// Iterator for walking over APFloat values. 331 class FloatElementIterator final 332 : public llvm::mapped_iterator_base<FloatElementIterator, 333 IntElementIterator, APFloat> { 334 public: 335 /// Map the element to the iterator result type. 336 APFloat mapElement(const APInt &value) const { 337 return APFloat(*smt, value); 338 } 339 340 private: 341 friend DenseElementsAttr; 342 343 /// Initializes the float element iterator to the specified iterator. 344 FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it) 345 : BaseT(it), smt(&smt) {} 346 347 /// The float semantics to use when constructing the APFloat. 348 const llvm::fltSemantics *smt; 349 }; 350 351 /// Iterator for walking over complex APFloat values. 352 class ComplexFloatElementIterator final 353 : public llvm::mapped_iterator_base<ComplexFloatElementIterator, 354 ComplexIntElementIterator, 355 std::complex<APFloat>> { 356 public: 357 /// Map the element to the iterator result type. 358 std::complex<APFloat> mapElement(const std::complex<APInt> &value) const { 359 return {APFloat(*smt, value.real()), APFloat(*smt, value.imag())}; 360 } 361 362 private: 363 friend DenseElementsAttr; 364 365 /// Initializes the float element iterator to the specified iterator. 366 ComplexFloatElementIterator(const llvm::fltSemantics &smt, 367 ComplexIntElementIterator it) 368 : BaseT(it), smt(&smt) {} 369 370 /// The float semantics to use when constructing the APFloat. 371 const llvm::fltSemantics *smt; 372 }; 373 374 //===--------------------------------------------------------------------===// 375 // Value Querying 376 //===--------------------------------------------------------------------===// 377 378 /// Returns true if this attribute corresponds to a splat, i.e. if all element 379 /// values are the same. 380 bool isSplat() const; 381 382 /// Return the splat value for this attribute. This asserts that the attribute 383 /// corresponds to a splat. 384 template <typename T> 385 std::enable_if_t<!std::is_base_of<Attribute, T>::value || 386 std::is_same<Attribute, T>::value, 387 T> 388 getSplatValue() const { 389 assert(isSplat() && "expected the attribute to be a splat"); 390 return *value_begin<T>(); 391 } 392 /// Return the splat value for derived attribute element types. 393 template <typename T> 394 std::enable_if_t<std::is_base_of<Attribute, T>::value && 395 !std::is_same<Attribute, T>::value, 396 T> 397 getSplatValue() const { 398 return llvm::cast<T>(getSplatValue<Attribute>()); 399 } 400 401 /// Try to get an iterator of the given type to the start of the held element 402 /// values. Return failure if the type cannot be iterated. 403 template <typename T> 404 auto try_value_begin() const { 405 auto range = tryGetValues<T>(); 406 using iterator = decltype(range->begin()); 407 return failed(range) ? FailureOr<iterator>(failure()) : range->begin(); 408 } 409 410 /// Try to get an iterator of the given type to the end of the held element 411 /// values. Return failure if the type cannot be iterated. 412 template <typename T> 413 auto try_value_end() const { 414 auto range = tryGetValues<T>(); 415 using iterator = decltype(range->begin()); 416 return failed(range) ? FailureOr<iterator>(failure()) : range->end(); 417 } 418 419 /// Return the held element values as a range of the given type. 420 template <typename T> 421 auto getValues() const { 422 auto range = tryGetValues<T>(); 423 assert(succeeded(range) && "element type cannot be iterated"); 424 return std::move(*range); 425 } 426 427 /// Get an iterator of the given type to the start of the held element values. 428 template <typename T> 429 auto value_begin() const { 430 return getValues<T>().begin(); 431 } 432 433 /// Get an iterator of the given type to the end of the held element values. 434 template <typename T> 435 auto value_end() const { 436 return getValues<T>().end(); 437 } 438 439 /// Try to get the held element values as a range of integer or floating-point 440 /// values. 441 template <typename T> 442 using IntFloatValueTemplateCheckT = 443 std::enable_if_t<(!std::is_same<T, bool>::value && 444 std::numeric_limits<T>::is_integer) || 445 is_valid_cpp_fp_type<T>::value>; 446 template <typename T, typename = IntFloatValueTemplateCheckT<T>> 447 FailureOr<iterator_range_impl<ElementIterator<T>>> tryGetValues() const { 448 if (!isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer, 449 std::numeric_limits<T>::is_signed)) 450 return failure(); 451 const char *rawData = getRawData().data(); 452 bool splat = isSplat(); 453 return iterator_range_impl<ElementIterator<T>>( 454 getType(), ElementIterator<T>(rawData, splat, 0), 455 ElementIterator<T>(rawData, splat, getNumElements())); 456 } 457 458 /// Try to get the held element values as a range of std::complex. 459 template <typename T, typename ElementT> 460 using ComplexValueTemplateCheckT = 461 std::enable_if_t<detail::is_complex_t<T>::value && 462 (std::numeric_limits<ElementT>::is_integer || 463 is_valid_cpp_fp_type<ElementT>::value)>; 464 template <typename T, typename ElementT = typename T::value_type, 465 typename = ComplexValueTemplateCheckT<T, ElementT>> 466 FailureOr<iterator_range_impl<ElementIterator<T>>> tryGetValues() const { 467 if (!isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer, 468 std::numeric_limits<ElementT>::is_signed)) 469 return failure(); 470 const char *rawData = getRawData().data(); 471 bool splat = isSplat(); 472 return iterator_range_impl<ElementIterator<T>>( 473 getType(), ElementIterator<T>(rawData, splat, 0), 474 ElementIterator<T>(rawData, splat, getNumElements())); 475 } 476 477 /// Try to get the held element values as a range of StringRef. 478 template <typename T> 479 using StringRefValueTemplateCheckT = 480 std::enable_if_t<std::is_same<T, StringRef>::value>; 481 template <typename T, typename = StringRefValueTemplateCheckT<T>> 482 FailureOr<iterator_range_impl<ElementIterator<StringRef>>> 483 tryGetValues() const { 484 auto stringRefs = getRawStringData(); 485 const char *ptr = reinterpret_cast<const char *>(stringRefs.data()); 486 bool splat = isSplat(); 487 return iterator_range_impl<ElementIterator<StringRef>>( 488 getType(), ElementIterator<StringRef>(ptr, splat, 0), 489 ElementIterator<StringRef>(ptr, splat, getNumElements())); 490 } 491 492 /// Try to get the held element values as a range of Attributes. 493 template <typename T> 494 using AttributeValueTemplateCheckT = 495 std::enable_if_t<std::is_same<T, Attribute>::value>; 496 template <typename T, typename = AttributeValueTemplateCheckT<T>> 497 FailureOr<iterator_range_impl<AttributeElementIterator>> 498 tryGetValues() const { 499 return iterator_range_impl<AttributeElementIterator>( 500 getType(), AttributeElementIterator(*this, 0), 501 AttributeElementIterator(*this, getNumElements())); 502 } 503 504 /// Try to get the held element values a range of T, where T is a derived 505 /// attribute type. 506 template <typename T> 507 using DerivedAttrValueTemplateCheckT = 508 std::enable_if_t<std::is_base_of<Attribute, T>::value && 509 !std::is_same<Attribute, T>::value>; 510 template <typename T> 511 struct DerivedAttributeElementIterator 512 : public llvm::mapped_iterator_base<DerivedAttributeElementIterator<T>, 513 AttributeElementIterator, T> { 514 using llvm::mapped_iterator_base<DerivedAttributeElementIterator<T>, 515 AttributeElementIterator, 516 T>::mapped_iterator_base; 517 518 /// Map the element to the iterator result type. 519 T mapElement(Attribute attr) const { return llvm::cast<T>(attr); } 520 }; 521 template <typename T, typename = DerivedAttrValueTemplateCheckT<T>> 522 FailureOr<iterator_range_impl<DerivedAttributeElementIterator<T>>> 523 tryGetValues() const { 524 using DerivedIterT = DerivedAttributeElementIterator<T>; 525 return iterator_range_impl<DerivedIterT>( 526 getType(), DerivedIterT(value_begin<Attribute>()), 527 DerivedIterT(value_end<Attribute>())); 528 } 529 530 /// Try to get the held element values as a range of bool. The element type of 531 /// this attribute must be of integer type of bitwidth 1. 532 template <typename T> 533 using BoolValueTemplateCheckT = 534 std::enable_if_t<std::is_same<T, bool>::value>; 535 template <typename T, typename = BoolValueTemplateCheckT<T>> 536 FailureOr<iterator_range_impl<BoolElementIterator>> tryGetValues() const { 537 if (!isValidBool()) 538 return failure(); 539 return iterator_range_impl<BoolElementIterator>( 540 getType(), BoolElementIterator(*this, 0), 541 BoolElementIterator(*this, getNumElements())); 542 } 543 544 /// Try to get the held element values as a range of APInts. The element type 545 /// of this attribute must be of integer type. 546 template <typename T> 547 using APIntValueTemplateCheckT = 548 std::enable_if_t<std::is_same<T, APInt>::value>; 549 template <typename T, typename = APIntValueTemplateCheckT<T>> 550 FailureOr<iterator_range_impl<IntElementIterator>> tryGetValues() const { 551 if (!getElementType().isIntOrIndex()) 552 return failure(); 553 return iterator_range_impl<IntElementIterator>(getType(), raw_int_begin(), 554 raw_int_end()); 555 } 556 557 /// Try to get the held element values as a range of complex APInts. The 558 /// element type of this attribute must be a complex of integer type. 559 template <typename T> 560 using ComplexAPIntValueTemplateCheckT = 561 std::enable_if_t<std::is_same<T, std::complex<APInt>>::value>; 562 template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>> 563 FailureOr<iterator_range_impl<ComplexIntElementIterator>> 564 tryGetValues() const { 565 return tryGetComplexIntValues(); 566 } 567 568 /// Try to get the held element values as a range of APFloat. The element type 569 /// of this attribute must be of float type. 570 template <typename T> 571 using APFloatValueTemplateCheckT = 572 std::enable_if_t<std::is_same<T, APFloat>::value>; 573 template <typename T, typename = APFloatValueTemplateCheckT<T>> 574 FailureOr<iterator_range_impl<FloatElementIterator>> tryGetValues() const { 575 return tryGetFloatValues(); 576 } 577 578 /// Try to get the held element values as a range of complex APFloat. The 579 /// element type of this attribute must be a complex of float type. 580 template <typename T> 581 using ComplexAPFloatValueTemplateCheckT = 582 std::enable_if_t<std::is_same<T, std::complex<APFloat>>::value>; 583 template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>> 584 FailureOr<iterator_range_impl<ComplexFloatElementIterator>> 585 tryGetValues() const { 586 return tryGetComplexFloatValues(); 587 } 588 589 /// Return the raw storage data held by this attribute. Users should generally 590 /// not use this directly, as the internal storage format is not always in the 591 /// form the user might expect. 592 ArrayRef<char> getRawData() const; 593 594 /// Return the raw StringRef data held by this attribute. 595 ArrayRef<StringRef> getRawStringData() const; 596 597 /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor 598 /// with static shape. 599 ShapedType getType() const; 600 601 /// Return the element type of this DenseElementsAttr. 602 Type getElementType() const; 603 604 /// Returns the number of elements held by this attribute. 605 int64_t getNumElements() const; 606 607 /// Returns the number of elements held by this attribute. 608 int64_t size() const { return getNumElements(); } 609 610 /// Returns if the number of elements held by this attribute is 0. 611 bool empty() const { return size() == 0; } 612 613 //===--------------------------------------------------------------------===// 614 // Mutation Utilities 615 //===--------------------------------------------------------------------===// 616 617 /// Return a new DenseElementsAttr that has the same data as the current 618 /// attribute, but has been reshaped to 'newType'. The new type must have the 619 /// same total number of elements as well as element type. 620 DenseElementsAttr reshape(ShapedType newType); 621 622 /// Return a new DenseElementsAttr that has the same data as the current 623 /// attribute, but with a different shape for a splat type. The new type must 624 /// have the same element type. 625 DenseElementsAttr resizeSplat(ShapedType newType); 626 627 /// Return a new DenseElementsAttr that has the same data as the current 628 /// attribute, but has bitcast elements to 'newElType'. The new type must have 629 /// the same bitwidth as the current element type. 630 DenseElementsAttr bitcast(Type newElType); 631 632 /// Generates a new DenseElementsAttr by mapping each int value to a new 633 /// underlying APInt. The new values can represent either an integer or float. 634 /// This underlying type must be an DenseIntElementsAttr. 635 DenseElementsAttr mapValues(Type newElementType, 636 function_ref<APInt(const APInt &)> mapping) const; 637 638 /// Generates a new DenseElementsAttr by mapping each float value to a new 639 /// underlying APInt. the new values can represent either an integer or float. 640 /// This underlying type must be an DenseFPElementsAttr. 641 DenseElementsAttr 642 mapValues(Type newElementType, 643 function_ref<APInt(const APFloat &)> mapping) const; 644 645 protected: 646 /// Iterators to various elements that require out-of-line definition. These 647 /// are hidden from the user to encourage consistent use of the 648 /// getValues/value_begin/value_end API. 649 IntElementIterator raw_int_begin() const { 650 return IntElementIterator(*this, 0); 651 } 652 IntElementIterator raw_int_end() const { 653 return IntElementIterator(*this, getNumElements()); 654 } 655 FailureOr<iterator_range_impl<ComplexIntElementIterator>> 656 tryGetComplexIntValues() const; 657 FailureOr<iterator_range_impl<FloatElementIterator>> 658 tryGetFloatValues() const; 659 FailureOr<iterator_range_impl<ComplexFloatElementIterator>> 660 tryGetComplexFloatValues() const; 661 662 /// Overload of the raw 'get' method that asserts that the given type is of 663 /// complex type. This method is used to verify type invariants that the 664 /// templatized 'get' method cannot. 665 static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef<char> data, 666 int64_t dataEltSize, bool isInt, 667 bool isSigned); 668 669 /// Overload of the raw 'get' method that asserts that the given type is of 670 /// integer or floating-point type. This method is used to verify type 671 /// invariants that the templatized 'get' method cannot. 672 static DenseElementsAttr getRawIntOrFloat(ShapedType type, 673 ArrayRef<char> data, 674 int64_t dataEltSize, bool isInt, 675 bool isSigned); 676 677 /// Check the information for a C++ data type, check if this type is valid for 678 /// the current attribute. This method is used to verify specific type 679 /// invariants that the templatized 'getValues' method cannot. 680 bool isValidBool() const { return getElementType().isInteger(1); } 681 bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const; 682 bool isValidComplex(int64_t dataEltSize, bool isInt, bool isSigned) const; 683 }; 684 685 /// An attribute that represents a reference to a splat vector or tensor 686 /// constant, meaning all of the elements have the same value. 687 class SplatElementsAttr : public DenseElementsAttr { 688 public: 689 using DenseElementsAttr::DenseElementsAttr; 690 691 /// Method for support type inquiry through isa, cast and dyn_cast. 692 static bool classof(Attribute attr) { 693 auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(attr); 694 return denseAttr && denseAttr.isSplat(); 695 } 696 }; 697 698 //===----------------------------------------------------------------------===// 699 // DenseResourceElementsAttr 700 //===----------------------------------------------------------------------===// 701 702 using DenseResourceElementsHandle = DialectResourceBlobHandle<BuiltinDialect>; 703 704 } // namespace mlir 705 706 //===----------------------------------------------------------------------===// 707 // Tablegen Attribute Declarations 708 //===----------------------------------------------------------------------===// 709 710 #define GET_ATTRDEF_CLASSES 711 #include "mlir/IR/BuiltinAttributes.h.inc" 712 713 //===----------------------------------------------------------------------===// 714 // C++ Attribute Declarations 715 //===----------------------------------------------------------------------===// 716 717 namespace mlir { 718 //===----------------------------------------------------------------------===// 719 // DenseArrayAttr 720 721 namespace detail { 722 /// Base class for DenseArrayAttr that is instantiated and specialized for each 723 /// supported element type below. 724 template <typename T> 725 class DenseArrayAttrImpl : public DenseArrayAttr { 726 public: 727 using DenseArrayAttr::DenseArrayAttr; 728 729 /// Implicit conversion to ArrayRef<T>. 730 operator ArrayRef<T>() const; 731 ArrayRef<T> asArrayRef() const { return ArrayRef<T>{*this}; } 732 733 /// Random access to elements. 734 T operator[](std::size_t index) const { return asArrayRef()[index]; } 735 736 /// Builder from ArrayRef<T>. 737 static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef<T> content); 738 739 /// Print the short form `[42, 100, -1]` without any type prefix. 740 void print(AsmPrinter &printer) const; 741 void print(raw_ostream &os) const; 742 /// Print the short form `42, 100, -1` without any braces or type prefix. 743 void printWithoutBraces(raw_ostream &os) const; 744 745 /// Parse the short form `[42, 100, -1]` without any type prefix. 746 static Attribute parse(AsmParser &parser, Type type); 747 748 /// Parse the short form `42, 100, -1` without any type prefix or braces. 749 static Attribute parseWithoutBraces(AsmParser &parser, Type type); 750 751 /// Support for isa<>/cast<>. 752 static bool classof(Attribute attr); 753 }; 754 755 extern template class DenseArrayAttrImpl<bool>; 756 extern template class DenseArrayAttrImpl<int8_t>; 757 extern template class DenseArrayAttrImpl<int16_t>; 758 extern template class DenseArrayAttrImpl<int32_t>; 759 extern template class DenseArrayAttrImpl<int64_t>; 760 extern template class DenseArrayAttrImpl<float>; 761 extern template class DenseArrayAttrImpl<double>; 762 } // namespace detail 763 764 // Public name for all the supported DenseArrayAttr 765 using DenseBoolArrayAttr = detail::DenseArrayAttrImpl<bool>; 766 using DenseI8ArrayAttr = detail::DenseArrayAttrImpl<int8_t>; 767 using DenseI16ArrayAttr = detail::DenseArrayAttrImpl<int16_t>; 768 using DenseI32ArrayAttr = detail::DenseArrayAttrImpl<int32_t>; 769 using DenseI64ArrayAttr = detail::DenseArrayAttrImpl<int64_t>; 770 using DenseF32ArrayAttr = detail::DenseArrayAttrImpl<float>; 771 using DenseF64ArrayAttr = detail::DenseArrayAttrImpl<double>; 772 773 //===----------------------------------------------------------------------===// 774 // DenseResourceElementsAttr 775 776 namespace detail { 777 /// Base class for DenseResourceElementsAttr that is instantiated and 778 /// specialized for each supported element type below. 779 template <typename T> 780 class DenseResourceElementsAttrBase : public DenseResourceElementsAttr { 781 public: 782 using DenseResourceElementsAttr::DenseResourceElementsAttr; 783 784 /// A builder that inserts a new resource using the provided blob. The handle 785 /// of the inserted blob is used when building the attribute. The provided 786 /// `blobName` is used as a hint for the key of the new handle for the `blob` 787 /// resource, but may be changed if necessary to ensure uniqueness during 788 /// insertion. 789 static DenseResourceElementsAttrBase<T> 790 get(ShapedType type, StringRef blobName, AsmResourceBlob blob); 791 792 /// Return the data of this attribute as an ArrayRef<T> if it is present, 793 /// returns std::nullopt otherwise. 794 std::optional<ArrayRef<T>> tryGetAsArrayRef() const; 795 796 /// Support for isa<>/cast<>. 797 static bool classof(Attribute attr); 798 }; 799 800 extern template class DenseResourceElementsAttrBase<bool>; 801 extern template class DenseResourceElementsAttrBase<int8_t>; 802 extern template class DenseResourceElementsAttrBase<int16_t>; 803 extern template class DenseResourceElementsAttrBase<int32_t>; 804 extern template class DenseResourceElementsAttrBase<int64_t>; 805 extern template class DenseResourceElementsAttrBase<uint8_t>; 806 extern template class DenseResourceElementsAttrBase<uint16_t>; 807 extern template class DenseResourceElementsAttrBase<uint32_t>; 808 extern template class DenseResourceElementsAttrBase<uint64_t>; 809 extern template class DenseResourceElementsAttrBase<float>; 810 extern template class DenseResourceElementsAttrBase<double>; 811 } // namespace detail 812 813 // Public names for all the supported DenseResourceElementsAttr. 814 815 using DenseBoolResourceElementsAttr = 816 detail::DenseResourceElementsAttrBase<bool>; 817 using DenseI8ResourceElementsAttr = 818 detail::DenseResourceElementsAttrBase<int8_t>; 819 using DenseI16ResourceElementsAttr = 820 detail::DenseResourceElementsAttrBase<int16_t>; 821 using DenseI32ResourceElementsAttr = 822 detail::DenseResourceElementsAttrBase<int32_t>; 823 using DenseI64ResourceElementsAttr = 824 detail::DenseResourceElementsAttrBase<int64_t>; 825 using DenseUI8ResourceElementsAttr = 826 detail::DenseResourceElementsAttrBase<uint8_t>; 827 using DenseUI16ResourceElementsAttr = 828 detail::DenseResourceElementsAttrBase<uint16_t>; 829 using DenseUI32ResourceElementsAttr = 830 detail::DenseResourceElementsAttrBase<uint32_t>; 831 using DenseUI64ResourceElementsAttr = 832 detail::DenseResourceElementsAttrBase<uint64_t>; 833 using DenseF32ResourceElementsAttr = 834 detail::DenseResourceElementsAttrBase<float>; 835 using DenseF64ResourceElementsAttr = 836 detail::DenseResourceElementsAttrBase<double>; 837 838 //===----------------------------------------------------------------------===// 839 // BoolAttr 840 //===----------------------------------------------------------------------===// 841 842 /// Special case of IntegerAttr to represent boolean integers, i.e., signless i1 843 /// integers. 844 class BoolAttr : public Attribute { 845 public: 846 using Attribute::Attribute; 847 using ValueType = bool; 848 849 static BoolAttr get(MLIRContext *context, bool value); 850 851 /// Enable conversion to IntegerAttr and its interfaces. This uses conversion 852 /// vs. inheritance to avoid bringing in all of IntegerAttrs methods. 853 operator IntegerAttr() const { return IntegerAttr(impl); } 854 operator TypedAttr() const { return IntegerAttr(impl); } 855 856 /// Return the boolean value of this attribute. 857 bool getValue() const; 858 859 /// Methods for support type inquiry through isa, cast, and dyn_cast. 860 static bool classof(Attribute attr); 861 }; 862 863 //===----------------------------------------------------------------------===// 864 // FlatSymbolRefAttr 865 //===----------------------------------------------------------------------===// 866 867 /// A symbol reference with a reference path containing a single element. This 868 /// is used to refer to an operation within the current symbol table. 869 class FlatSymbolRefAttr : public SymbolRefAttr { 870 public: 871 using SymbolRefAttr::SymbolRefAttr; 872 using ValueType = StringRef; 873 874 /// Construct a symbol reference for the given value name. 875 static FlatSymbolRefAttr get(StringAttr value) { 876 return SymbolRefAttr::get(value); 877 } 878 static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value) { 879 return SymbolRefAttr::get(ctx, value); 880 } 881 882 /// Convenience getter for building a SymbolRefAttr based on an operation 883 /// that implements the SymbolTrait. 884 static FlatSymbolRefAttr get(Operation *symbol) { 885 return SymbolRefAttr::get(symbol); 886 } 887 888 /// Returns the name of the held symbol reference as a StringAttr. 889 StringAttr getAttr() const { return getRootReference(); } 890 891 /// Returns the name of the held symbol reference. 892 StringRef getValue() const { return getAttr().getValue(); } 893 894 /// Methods for support type inquiry through isa, cast, and dyn_cast. 895 static bool classof(Attribute attr) { 896 SymbolRefAttr refAttr = llvm::dyn_cast<SymbolRefAttr>(attr); 897 return refAttr && refAttr.getNestedReferences().empty(); 898 } 899 900 private: 901 using SymbolRefAttr::get; 902 using SymbolRefAttr::getNestedReferences; 903 }; 904 905 //===----------------------------------------------------------------------===// 906 // DenseFPElementsAttr 907 //===----------------------------------------------------------------------===// 908 909 /// An attribute that represents a reference to a dense float vector or tensor 910 /// object. Each element is stored as a double. 911 class DenseFPElementsAttr : public DenseIntOrFPElementsAttr { 912 public: 913 using iterator = DenseElementsAttr::FloatElementIterator; 914 915 using DenseIntOrFPElementsAttr::DenseIntOrFPElementsAttr; 916 917 /// Get an instance of a DenseFPElementsAttr with the given arguments. This 918 /// simply wraps the DenseElementsAttr::get calls. 919 template <typename Arg> 920 static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg) { 921 return llvm::cast<DenseFPElementsAttr>( 922 DenseElementsAttr::get(type, llvm::ArrayRef(arg))); 923 } 924 template <typename T> 925 static DenseFPElementsAttr get(const ShapedType &type, 926 const std::initializer_list<T> &list) { 927 return llvm::cast<DenseFPElementsAttr>(DenseElementsAttr::get(type, list)); 928 } 929 930 /// Generates a new DenseElementsAttr by mapping each value attribute, and 931 /// constructing the DenseElementsAttr given the new element type. 932 DenseElementsAttr 933 mapValues(Type newElementType, 934 function_ref<APInt(const APFloat &)> mapping) const; 935 936 /// Iterator access to the float element values. 937 iterator begin() const { return tryGetFloatValues()->begin(); } 938 iterator end() const { return tryGetFloatValues()->end(); } 939 940 /// Method for supporting type inquiry through isa, cast and dyn_cast. 941 static bool classof(Attribute attr); 942 }; 943 944 //===----------------------------------------------------------------------===// 945 // DenseIntElementsAttr 946 //===----------------------------------------------------------------------===// 947 948 /// An attribute that represents a reference to a dense integer vector or tensor 949 /// object. 950 class DenseIntElementsAttr : public DenseIntOrFPElementsAttr { 951 public: 952 /// DenseIntElementsAttr iterates on APInt, so we can use the raw element 953 /// iterator directly. 954 using iterator = DenseElementsAttr::IntElementIterator; 955 956 using DenseIntOrFPElementsAttr::DenseIntOrFPElementsAttr; 957 958 /// Get an instance of a DenseIntElementsAttr with the given arguments. This 959 /// simply wraps the DenseElementsAttr::get calls. 960 template <typename Arg> 961 static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg) { 962 return llvm::cast<DenseIntElementsAttr>( 963 DenseElementsAttr::get(type, llvm::ArrayRef(arg))); 964 } 965 template <typename T> 966 static DenseIntElementsAttr get(const ShapedType &type, 967 const std::initializer_list<T> &list) { 968 return llvm::cast<DenseIntElementsAttr>(DenseElementsAttr::get(type, list)); 969 } 970 971 /// Generates a new DenseElementsAttr by mapping each value attribute, and 972 /// constructing the DenseElementsAttr given the new element type. 973 DenseElementsAttr mapValues(Type newElementType, 974 function_ref<APInt(const APInt &)> mapping) const; 975 976 /// Iterator access to the integer element values. 977 iterator begin() const { return raw_int_begin(); } 978 iterator end() const { return raw_int_end(); } 979 980 /// Method for supporting type inquiry through isa, cast and dyn_cast. 981 static bool classof(Attribute attr); 982 }; 983 984 //===----------------------------------------------------------------------===// 985 // SparseElementsAttr 986 //===----------------------------------------------------------------------===// 987 988 template <typename T> 989 auto SparseElementsAttr::try_value_begin_impl(OverloadToken<T>) const 990 -> FailureOr<iterator<T>> { 991 auto zeroValue = getZeroValue<T>(); 992 auto valueIt = getValues().try_value_begin<T>(); 993 if (failed(valueIt)) 994 return failure(); 995 const std::vector<ptrdiff_t> flatSparseIndices(getFlattenedSparseIndices()); 996 std::function<T(ptrdiff_t)> mapFn = 997 [flatSparseIndices{flatSparseIndices}, valueIt{std::move(*valueIt)}, 998 zeroValue{std::move(zeroValue)}](ptrdiff_t index) { 999 // Try to map the current index to one of the sparse indices. 1000 for (unsigned i = 0, e = flatSparseIndices.size(); i != e; ++i) 1001 if (flatSparseIndices[i] == index) 1002 return *std::next(valueIt, i); 1003 // Otherwise, return the zero value. 1004 return zeroValue; 1005 }; 1006 return iterator<T>(llvm::seq<ptrdiff_t>(0, getNumElements()).begin(), mapFn); 1007 } 1008 1009 //===----------------------------------------------------------------------===// 1010 // DistinctAttr 1011 //===----------------------------------------------------------------------===// 1012 1013 namespace detail { 1014 struct DistinctAttrStorage; 1015 class DistinctAttributeUniquer; 1016 } // namespace detail 1017 1018 /// An attribute that associates a referenced attribute with a unique 1019 /// identifier. Every call to the create function allocates a new distinct 1020 /// attribute instance. The address of the attribute instance serves as a 1021 /// temporary identifier. Similar to the names of SSA values, the final 1022 /// identifiers are generated during pretty printing. This delayed numbering 1023 /// ensures the printed identifiers are deterministic even if multiple distinct 1024 /// attribute instances are created in-parallel. 1025 /// 1026 /// Examples: 1027 /// 1028 /// #distinct = distinct[0]<42.0 : f32> 1029 /// #distinct1 = distinct[1]<42.0 : f32> 1030 /// #distinct2 = distinct[2]<array<i32: 10, 42>> 1031 /// 1032 /// NOTE: The distinct attribute cannot be defined using ODS since it uses a 1033 /// custom distinct attribute uniquer that cannot be set from ODS. 1034 class DistinctAttr 1035 : public detail::StorageUserBase<DistinctAttr, Attribute, 1036 detail::DistinctAttrStorage, 1037 detail::DistinctAttributeUniquer> { 1038 public: 1039 using Base::Base; 1040 1041 /// Returns the referenced attribute. 1042 Attribute getReferencedAttr() const; 1043 1044 /// Creates a distinct attribute that associates a referenced attribute with a 1045 /// unique identifier. 1046 static DistinctAttr create(Attribute referencedAttr); 1047 1048 static constexpr StringLiteral name = "builtin.distinct"; 1049 }; 1050 1051 //===----------------------------------------------------------------------===// 1052 // StringAttr 1053 //===----------------------------------------------------------------------===// 1054 1055 /// Define comparisons for StringAttr against nullptr and itself to avoid the 1056 /// StringRef overloads from being chosen when not desirable. 1057 inline bool operator==(StringAttr lhs, std::nullptr_t) { return !lhs; } 1058 inline bool operator!=(StringAttr lhs, std::nullptr_t) { 1059 return static_cast<bool>(lhs); 1060 } 1061 inline bool operator==(StringAttr lhs, StringAttr rhs) { 1062 return (Attribute)lhs == (Attribute)rhs; 1063 } 1064 inline bool operator!=(StringAttr lhs, StringAttr rhs) { return !(lhs == rhs); } 1065 1066 /// Allow direct comparison with StringRef. 1067 inline bool operator==(StringAttr lhs, StringRef rhs) { 1068 return lhs.getValue() == rhs; 1069 } 1070 inline bool operator!=(StringAttr lhs, StringRef rhs) { return !(lhs == rhs); } 1071 inline bool operator==(StringRef lhs, StringAttr rhs) { 1072 return rhs.getValue() == lhs; 1073 } 1074 inline bool operator!=(StringRef lhs, StringAttr rhs) { return !(lhs == rhs); } 1075 1076 } // namespace mlir 1077 1078 //===----------------------------------------------------------------------===// 1079 // Attribute Utilities 1080 //===----------------------------------------------------------------------===// 1081 1082 namespace mlir { 1083 1084 /// Given a list of strides (in which ShapedType::kDynamic 1085 /// represents a dynamic value), return the single result AffineMap which 1086 /// represents the linearized strided layout map. Dimensions correspond to the 1087 /// offset followed by the strides in order. Symbols are inserted for each 1088 /// dynamic dimension in order. A stride is always positive. 1089 /// 1090 /// Examples: 1091 /// ========= 1092 /// 1093 /// 1. For offset: 0 strides: ?, ?, 1 return 1094 /// (i, j, k)[M, N]->(M * i + N * j + k) 1095 /// 1096 /// 2. For offset: 3 strides: 32, ?, 16 return 1097 /// (i, j, k)[M]->(3 + 32 * i + M * j + 16 * k) 1098 /// 1099 /// 3. For offset: ? strides: ?, ?, ? return 1100 /// (i, j, k)[off, M, N, P]->(off + M * i + N * j + P * k) 1101 AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset, 1102 MLIRContext *context); 1103 1104 } // namespace mlir 1105 1106 namespace llvm { 1107 1108 template <> 1109 struct DenseMapInfo<mlir::StringAttr> : public DenseMapInfo<mlir::Attribute> { 1110 static mlir::StringAttr getEmptyKey() { 1111 const void *pointer = llvm::DenseMapInfo<const void *>::getEmptyKey(); 1112 return mlir::StringAttr::getFromOpaquePointer(pointer); 1113 } 1114 static mlir::StringAttr getTombstoneKey() { 1115 const void *pointer = llvm::DenseMapInfo<const void *>::getTombstoneKey(); 1116 return mlir::StringAttr::getFromOpaquePointer(pointer); 1117 } 1118 }; 1119 template <> 1120 struct PointerLikeTypeTraits<mlir::StringAttr> 1121 : public PointerLikeTypeTraits<mlir::Attribute> { 1122 static inline mlir::StringAttr getFromVoidPointer(void *p) { 1123 return mlir::StringAttr::getFromOpaquePointer(p); 1124 } 1125 }; 1126 1127 template <> 1128 struct PointerLikeTypeTraits<mlir::IntegerAttr> 1129 : public PointerLikeTypeTraits<mlir::Attribute> { 1130 static inline mlir::IntegerAttr getFromVoidPointer(void *p) { 1131 return mlir::IntegerAttr::getFromOpaquePointer(p); 1132 } 1133 }; 1134 1135 template <> 1136 struct PointerLikeTypeTraits<mlir::SymbolRefAttr> 1137 : public PointerLikeTypeTraits<mlir::Attribute> { 1138 static inline mlir::SymbolRefAttr getFromVoidPointer(void *ptr) { 1139 return mlir::SymbolRefAttr::getFromOpaquePointer(ptr); 1140 } 1141 }; 1142 1143 } // namespace llvm 1144 1145 #endif // MLIR_IR_BUILTINATTRIBUTES_H 1146