xref: /llvm-project/mlir/lib/IR/BuiltinAttributes.cpp (revision 72e8b9aeaa3f584f223bc59924812df69a09a48b)
1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinAttributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BuiltinDialect.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/DialectResourceBlobManager.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/OpImplementation.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/IR/SymbolTable.h"
19 #include "mlir/IR/Types.h"
20 #include "llvm/ADT/APSInt.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/Endian.h"
25 #include <optional>
26 
27 #define DEBUG_TYPE "builtinattributes"
28 
29 using namespace mlir;
30 using namespace mlir::detail;
31 
32 //===----------------------------------------------------------------------===//
33 /// Tablegen Attribute Definitions
34 //===----------------------------------------------------------------------===//
35 
36 #define GET_ATTRDEF_CLASSES
37 #include "mlir/IR/BuiltinAttributes.cpp.inc"
38 
39 //===----------------------------------------------------------------------===//
40 // BuiltinDialect
41 //===----------------------------------------------------------------------===//
42 
43 void BuiltinDialect::registerAttributes() {
44   addAttributes<
45 #define GET_ATTRDEF_LIST
46 #include "mlir/IR/BuiltinAttributes.cpp.inc"
47       >();
48   addAttributes<DistinctAttr>();
49 }
50 
51 //===----------------------------------------------------------------------===//
52 // DictionaryAttr
53 //===----------------------------------------------------------------------===//
54 
55 /// Helper function that does either an in place sort or sorts from source array
56 /// into destination. If inPlace then storage is both the source and the
57 /// destination, else value is the source and storage destination. Returns
58 /// whether source was sorted.
59 template <bool inPlace>
60 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
61                                SmallVectorImpl<NamedAttribute> &storage) {
62   // Specialize for the common case.
63   switch (value.size()) {
64   case 0:
65     // Zero already sorted.
66     if (!inPlace)
67       storage.clear();
68     break;
69   case 1:
70     // One already sorted but may need to be copied.
71     if (!inPlace)
72       storage.assign({value[0]});
73     break;
74   case 2: {
75     bool isSorted = value[0] < value[1];
76     if (inPlace) {
77       if (!isSorted)
78         std::swap(storage[0], storage[1]);
79     } else if (isSorted) {
80       storage.assign({value[0], value[1]});
81     } else {
82       storage.assign({value[1], value[0]});
83     }
84     return !isSorted;
85   }
86   default:
87     if (!inPlace)
88       storage.assign(value.begin(), value.end());
89     // Check to see they are sorted already.
90     bool isSorted = llvm::is_sorted(value);
91     // If not, do a general sort.
92     if (!isSorted)
93       llvm::array_pod_sort(storage.begin(), storage.end());
94     return !isSorted;
95   }
96   return false;
97 }
98 
99 /// Returns an entry with a duplicate name from the given sorted array of named
100 /// attributes. Returns std::nullopt if all elements have unique names.
101 static std::optional<NamedAttribute>
102 findDuplicateElement(ArrayRef<NamedAttribute> value) {
103   const std::optional<NamedAttribute> none{std::nullopt};
104   if (value.size() < 2)
105     return none;
106 
107   if (value.size() == 2)
108     return value[0].getName() == value[1].getName() ? value[0] : none;
109 
110   const auto *it = std::adjacent_find(value.begin(), value.end(),
111                                       [](NamedAttribute l, NamedAttribute r) {
112                                         return l.getName() == r.getName();
113                                       });
114   return it != value.end() ? *it : none;
115 }
116 
117 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
118                           SmallVectorImpl<NamedAttribute> &storage) {
119   bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
120   assert(!findDuplicateElement(storage) &&
121          "DictionaryAttr element names must be unique");
122   return isSorted;
123 }
124 
125 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
126   bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
127   assert(!findDuplicateElement(array) &&
128          "DictionaryAttr element names must be unique");
129   return isSorted;
130 }
131 
132 std::optional<NamedAttribute>
133 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
134                               bool isSorted) {
135   if (!isSorted)
136     dictionaryAttrSort</*inPlace=*/true>(array, array);
137   return findDuplicateElement(array);
138 }
139 
140 DictionaryAttr DictionaryAttr::get(MLIRContext *context,
141                                    ArrayRef<NamedAttribute> value) {
142   if (value.empty())
143     return DictionaryAttr::getEmpty(context);
144 
145   // We need to sort the element list to canonicalize it.
146   SmallVector<NamedAttribute, 8> storage;
147   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
148     value = storage;
149   assert(!findDuplicateElement(value) &&
150          "DictionaryAttr element names must be unique");
151   return Base::get(context, value);
152 }
153 /// Construct a dictionary with an array of values that is known to already be
154 /// sorted by name and uniqued.
155 DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
156                                              ArrayRef<NamedAttribute> value) {
157   if (value.empty())
158     return DictionaryAttr::getEmpty(context);
159   // Ensure that the attribute elements are unique and sorted.
160   assert(llvm::is_sorted(
161              value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) &&
162          "expected attribute values to be sorted");
163   assert(!findDuplicateElement(value) &&
164          "DictionaryAttr element names must be unique");
165   return Base::get(context, value);
166 }
167 
168 /// Return the specified attribute if present, null otherwise.
169 Attribute DictionaryAttr::get(StringRef name) const {
170   auto it = impl::findAttrSorted(begin(), end(), name);
171   return it.second ? it.first->getValue() : Attribute();
172 }
173 Attribute DictionaryAttr::get(StringAttr name) const {
174   auto it = impl::findAttrSorted(begin(), end(), name);
175   return it.second ? it.first->getValue() : Attribute();
176 }
177 
178 /// Return the specified named attribute if present, std::nullopt otherwise.
179 std::optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
180   auto it = impl::findAttrSorted(begin(), end(), name);
181   return it.second ? *it.first : std::optional<NamedAttribute>();
182 }
183 std::optional<NamedAttribute> DictionaryAttr::getNamed(StringAttr name) const {
184   auto it = impl::findAttrSorted(begin(), end(), name);
185   return it.second ? *it.first : std::optional<NamedAttribute>();
186 }
187 
188 /// Return whether the specified attribute is present.
189 bool DictionaryAttr::contains(StringRef name) const {
190   return impl::findAttrSorted(begin(), end(), name).second;
191 }
192 bool DictionaryAttr::contains(StringAttr name) const {
193   return impl::findAttrSorted(begin(), end(), name).second;
194 }
195 
196 DictionaryAttr::iterator DictionaryAttr::begin() const {
197   return getValue().begin();
198 }
199 DictionaryAttr::iterator DictionaryAttr::end() const {
200   return getValue().end();
201 }
202 size_t DictionaryAttr::size() const { return getValue().size(); }
203 
204 DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
205   return Base::get(context, ArrayRef<NamedAttribute>());
206 }
207 
208 //===----------------------------------------------------------------------===//
209 // StridedLayoutAttr
210 //===----------------------------------------------------------------------===//
211 
212 /// Prints a strided layout attribute.
213 void StridedLayoutAttr::print(llvm::raw_ostream &os) const {
214   auto printIntOrQuestion = [&](int64_t value) {
215     if (ShapedType::isDynamic(value))
216       os << "?";
217     else
218       os << value;
219   };
220 
221   os << "strided<[";
222   llvm::interleaveComma(getStrides(), os, printIntOrQuestion);
223   os << "]";
224 
225   if (getOffset() != 0) {
226     os << ", offset: ";
227     printIntOrQuestion(getOffset());
228   }
229   os << ">";
230 }
231 
232 /// Returns true if this layout is static, i.e. the strides and offset all have
233 /// a known value > 0.
234 bool StridedLayoutAttr::hasStaticLayout() const {
235   return !ShapedType::isDynamic(getOffset()) &&
236          !ShapedType::isDynamicShape(getStrides());
237 }
238 
239 /// Returns the strided layout as an affine map.
240 AffineMap StridedLayoutAttr::getAffineMap() const {
241   return makeStridedLinearLayoutMap(getStrides(), getOffset(), getContext());
242 }
243 
244 /// Checks that the type-agnostic strided layout invariants are satisfied.
245 LogicalResult
246 StridedLayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
247                           int64_t offset, ArrayRef<int64_t> strides) {
248   return success();
249 }
250 
251 /// Checks that the type-specific strided layout invariants are satisfied.
252 LogicalResult StridedLayoutAttr::verifyLayout(
253     ArrayRef<int64_t> shape,
254     function_ref<InFlightDiagnostic()> emitError) const {
255   if (shape.size() != getStrides().size())
256     return emitError() << "expected the number of strides to match the rank";
257 
258   return success();
259 }
260 
261 //===----------------------------------------------------------------------===//
262 // StringAttr
263 //===----------------------------------------------------------------------===//
264 
265 StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) {
266   return Base::get(context, "", NoneType::get(context));
267 }
268 
269 /// Twine support for StringAttr.
270 StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) {
271   // Fast-path empty twine.
272   if (twine.isTriviallyEmpty())
273     return get(context);
274   SmallVector<char, 32> tempStr;
275   return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context));
276 }
277 
278 /// Twine support for StringAttr.
279 StringAttr StringAttr::get(const Twine &twine, Type type) {
280   SmallVector<char, 32> tempStr;
281   return Base::get(type.getContext(), twine.toStringRef(tempStr), type);
282 }
283 
284 StringRef StringAttr::getValue() const { return getImpl()->value; }
285 
286 Type StringAttr::getType() const { return getImpl()->type; }
287 
288 Dialect *StringAttr::getReferencedDialect() const {
289   return getImpl()->referencedDialect;
290 }
291 
292 //===----------------------------------------------------------------------===//
293 // FloatAttr
294 //===----------------------------------------------------------------------===//
295 
296 double FloatAttr::getValueAsDouble() const {
297   return getValueAsDouble(getValue());
298 }
299 double FloatAttr::getValueAsDouble(APFloat value) {
300   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
301     bool losesInfo = false;
302     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
303                   &losesInfo);
304   }
305   return value.convertToDouble();
306 }
307 
308 LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
309                                 Type type, APFloat value) {
310   // Verify that the type is correct.
311   if (!llvm::isa<FloatType>(type))
312     return emitError() << "expected floating point type";
313 
314   // Verify that the type semantics match that of the value.
315   if (&llvm::cast<FloatType>(type).getFloatSemantics() !=
316       &value.getSemantics()) {
317     return emitError()
318            << "FloatAttr type doesn't match the type implied by its value";
319   }
320   return success();
321 }
322 
323 //===----------------------------------------------------------------------===//
324 // SymbolRefAttr
325 //===----------------------------------------------------------------------===//
326 
327 SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
328                                  ArrayRef<FlatSymbolRefAttr> nestedRefs) {
329   return get(StringAttr::get(ctx, value), nestedRefs);
330 }
331 
332 FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
333   return llvm::cast<FlatSymbolRefAttr>(get(ctx, value, {}));
334 }
335 
336 FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) {
337   return llvm::cast<FlatSymbolRefAttr>(get(value, {}));
338 }
339 
340 FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) {
341   auto symName =
342       symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
343   assert(symName && "value does not have a valid symbol name");
344   return SymbolRefAttr::get(symName);
345 }
346 
347 StringAttr SymbolRefAttr::getLeafReference() const {
348   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
349   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
350 }
351 
352 //===----------------------------------------------------------------------===//
353 // IntegerAttr
354 //===----------------------------------------------------------------------===//
355 
356 int64_t IntegerAttr::getInt() const {
357   assert((getType().isIndex() || getType().isSignlessInteger()) &&
358          "must be signless integer");
359   return getValue().getSExtValue();
360 }
361 
362 int64_t IntegerAttr::getSInt() const {
363   assert(getType().isSignedInteger() && "must be signed integer");
364   return getValue().getSExtValue();
365 }
366 
367 uint64_t IntegerAttr::getUInt() const {
368   assert(getType().isUnsignedInteger() && "must be unsigned integer");
369   return getValue().getZExtValue();
370 }
371 
372 /// Return the value as an APSInt which carries the signed from the type of
373 /// the attribute.  This traps on signless integers types!
374 APSInt IntegerAttr::getAPSInt() const {
375   assert(!getType().isSignlessInteger() &&
376          "Signless integers don't carry a sign for APSInt");
377   return APSInt(getValue(), getType().isUnsignedInteger());
378 }
379 
380 LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
381                                   Type type, APInt value) {
382   if (IntegerType integerType = llvm::dyn_cast<IntegerType>(type)) {
383     if (integerType.getWidth() != value.getBitWidth())
384       return emitError() << "integer type bit width (" << integerType.getWidth()
385                          << ") doesn't match value bit width ("
386                          << value.getBitWidth() << ")";
387     return success();
388   }
389   if (llvm::isa<IndexType>(type)) {
390     if (value.getBitWidth() != IndexType::kInternalStorageBitWidth)
391       return emitError()
392              << "value bit width (" << value.getBitWidth()
393              << ") doesn't match index type internal storage bit width ("
394              << IndexType::kInternalStorageBitWidth << ")";
395     return success();
396   }
397   return emitError() << "expected integer or index type";
398 }
399 
400 BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
401   auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value));
402   return llvm::cast<BoolAttr>(attr);
403 }
404 
405 //===----------------------------------------------------------------------===//
406 // BoolAttr
407 //===----------------------------------------------------------------------===//
408 
409 bool BoolAttr::getValue() const {
410   auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
411   return storage->value.getBoolValue();
412 }
413 
414 bool BoolAttr::classof(Attribute attr) {
415   IntegerAttr intAttr = llvm::dyn_cast<IntegerAttr>(attr);
416   return intAttr && intAttr.getType().isSignlessInteger(1);
417 }
418 
419 //===----------------------------------------------------------------------===//
420 // OpaqueAttr
421 //===----------------------------------------------------------------------===//
422 
423 LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
424                                  StringAttr dialect, StringRef attrData,
425                                  Type type) {
426   if (!Dialect::isValidNamespace(dialect.strref()))
427     return emitError() << "invalid dialect namespace '" << dialect << "'";
428 
429   // Check that the dialect is actually registered.
430   MLIRContext *context = dialect.getContext();
431   if (!context->allowsUnregisteredDialects() &&
432       !context->getLoadedDialect(dialect.strref())) {
433     return emitError()
434            << "#" << dialect << "<\"" << attrData << "\"> : " << type
435            << " attribute created with unregistered dialect. If this is "
436               "intended, please call allowUnregisteredDialects() on the "
437               "MLIRContext, or use -allow-unregistered-dialect with "
438               "the MLIR opt tool used";
439   }
440 
441   return success();
442 }
443 
444 //===----------------------------------------------------------------------===//
445 // DenseElementsAttr Utilities
446 //===----------------------------------------------------------------------===//
447 
448 const char DenseIntOrFPElementsAttrStorage::kSplatTrue = ~0;
449 const char DenseIntOrFPElementsAttrStorage::kSplatFalse = 0;
450 
451 /// Get the bitwidth of a dense element type within the buffer.
452 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
453 static size_t getDenseElementStorageWidth(size_t origWidth) {
454   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
455 }
456 static size_t getDenseElementStorageWidth(Type elementType) {
457   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
458 }
459 
460 /// Set a bit to a specific value.
461 static void setBit(char *rawData, size_t bitPos, bool value) {
462   if (value)
463     rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT));
464   else
465     rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT));
466 }
467 
468 /// Return the value of the specified bit.
469 static bool getBit(const char *rawData, size_t bitPos) {
470   return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0;
471 }
472 
473 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for
474 /// BE format.
475 static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
476                                          char *result) {
477   assert(llvm::endianness::native == llvm::endianness::big);
478   assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
479 
480   // Copy the words filled with data.
481   // For example, when `value` has 2 words, the first word is filled with data.
482   // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
483   size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
484   std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
485               numFilledWords, result);
486   // Convert last word of APInt to LE format and store it in char
487   // array(`valueLE`).
488   // ex. last word of `value` (BE): |------ij|  ==> `valueLE` (LE): |ji------|
489   size_t lastWordPos = numFilledWords;
490   SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
491   DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
492       reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
493       valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
494   // Extract actual APInt data from `valueLE`, convert endianness to BE format,
495   // and store it in `result`.
496   // ex. `valueLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|ij|
497   DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
498       valueLE.begin(), result + lastWordPos,
499       (numBytes - lastWordPos) * CHAR_BIT, 1);
500 }
501 
502 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
503 /// format.
504 static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
505                                          APInt &result) {
506   assert(llvm::endianness::native == llvm::endianness::big);
507   assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
508 
509   // Copy the data that fills the word of `result` from `inArray`.
510   // For example, when `result` has 2 words, the first word will be filled with
511   // data. So, the first 8 bytes are copied from `inArray` here.
512   // `inArray` (10 bytes, BE): |abcdefgh|ij|
513   //                     ==> `result` (2 words, BE): |abcdefgh|--------|
514   size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
515   std::copy_n(
516       inArray, numFilledWords,
517       const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));
518 
519   // Convert array data which will be last word of `result` to LE format, and
520   // store it in char array(`inArrayLE`).
521   // ex. `inArray` (last two bytes, BE): |ij|  ==> `inArrayLE` (LE): |ji------|
522   size_t lastWordPos = numFilledWords;
523   SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
524   DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
525       inArray + lastWordPos, inArrayLE.begin(),
526       (numBytes - lastWordPos) * CHAR_BIT, 1);
527 
528   // Convert `inArrayLE` to BE format, and store it in last word of `result`.
529   // ex. `inArrayLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|------ij|
530   DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
531       inArrayLE.begin(),
532       const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
533           lastWordPos,
534       APInt::APINT_BITS_PER_WORD, 1);
535 }
536 
537 /// Writes value to the bit position `bitPos` in array `rawData`.
538 static void writeBits(char *rawData, size_t bitPos, APInt value) {
539   size_t bitWidth = value.getBitWidth();
540 
541   // If the bitwidth is 1 we just toggle the specific bit.
542   if (bitWidth == 1)
543     return setBit(rawData, bitPos, value.isOne());
544 
545   // Otherwise, the bit position is guaranteed to be byte aligned.
546   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
547   if (llvm::endianness::native == llvm::endianness::big) {
548     // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
549     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
550     // work correctly in BE format.
551     // ex. `value` (2 words including 10 bytes)
552     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------|
553     copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT),
554                                  rawData + (bitPos / CHAR_BIT));
555   } else {
556     std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
557                 llvm::divideCeil(bitWidth, CHAR_BIT),
558                 rawData + (bitPos / CHAR_BIT));
559   }
560 }
561 
562 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
563 /// `rawData`.
564 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
565   // Handle a boolean bit position.
566   if (bitWidth == 1)
567     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
568 
569   // Otherwise, the bit position must be 8-bit aligned.
570   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
571   APInt result(bitWidth, 0);
572   if (llvm::endianness::native == llvm::endianness::big) {
573     // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
574     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
575     // work correctly in BE format.
576     // ex. `result` (2 words including 10 bytes)
577     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------| This function
578     copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
579                                  llvm::divideCeil(bitWidth, CHAR_BIT), result);
580   } else {
581     std::copy_n(rawData + (bitPos / CHAR_BIT),
582                 llvm::divideCeil(bitWidth, CHAR_BIT),
583                 const_cast<char *>(
584                     reinterpret_cast<const char *>(result.getRawData())));
585   }
586   return result;
587 }
588 
589 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has
590 /// the same element count as 'type'.
591 template <typename Values>
592 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
593   return (values.size() == 1) ||
594          (type.getNumElements() == static_cast<int64_t>(values.size()));
595 }
596 
597 //===----------------------------------------------------------------------===//
598 // DenseElementsAttr Iterators
599 //===----------------------------------------------------------------------===//
600 
601 //===----------------------------------------------------------------------===//
602 // AttributeElementIterator
603 
604 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
605     DenseElementsAttr attr, size_t index)
606     : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
607                                       Attribute, Attribute, Attribute>(
608           attr.getAsOpaquePointer(), index) {}
609 
610 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
611   auto owner = llvm::cast<DenseElementsAttr>(getFromOpaquePointer(base));
612   Type eltTy = owner.getElementType();
613   if (llvm::dyn_cast<IntegerType>(eltTy))
614     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
615   if (llvm::isa<IndexType>(eltTy))
616     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
617   if (auto floatEltTy = llvm::dyn_cast<FloatType>(eltTy)) {
618     IntElementIterator intIt(owner, index);
619     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
620     return FloatAttr::get(eltTy, *floatIt);
621   }
622   if (auto complexTy = llvm::dyn_cast<ComplexType>(eltTy)) {
623     auto complexEltTy = complexTy.getElementType();
624     ComplexIntElementIterator complexIntIt(owner, index);
625     if (llvm::isa<IntegerType>(complexEltTy)) {
626       auto value = *complexIntIt;
627       auto real = IntegerAttr::get(complexEltTy, value.real());
628       auto imag = IntegerAttr::get(complexEltTy, value.imag());
629       return ArrayAttr::get(complexTy.getContext(),
630                             ArrayRef<Attribute>{real, imag});
631     }
632 
633     ComplexFloatElementIterator complexFloatIt(
634         llvm::cast<FloatType>(complexEltTy).getFloatSemantics(), complexIntIt);
635     auto value = *complexFloatIt;
636     auto real = FloatAttr::get(complexEltTy, value.real());
637     auto imag = FloatAttr::get(complexEltTy, value.imag());
638     return ArrayAttr::get(complexTy.getContext(),
639                           ArrayRef<Attribute>{real, imag});
640   }
641   if (llvm::isa<DenseStringElementsAttr>(owner)) {
642     ArrayRef<StringRef> vals = owner.getRawStringData();
643     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
644   }
645   llvm_unreachable("unexpected element type");
646 }
647 
648 //===----------------------------------------------------------------------===//
649 // BoolElementIterator
650 
651 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
652     DenseElementsAttr attr, size_t dataIndex)
653     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
654           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
655 
656 bool DenseElementsAttr::BoolElementIterator::operator*() const {
657   return getBit(getData(), getDataIndex());
658 }
659 
660 //===----------------------------------------------------------------------===//
661 // IntElementIterator
662 
663 DenseElementsAttr::IntElementIterator::IntElementIterator(
664     DenseElementsAttr attr, size_t dataIndex)
665     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
666           attr.getRawData().data(), attr.isSplat(), dataIndex),
667       bitWidth(getDenseElementBitWidth(attr.getElementType())) {}
668 
669 APInt DenseElementsAttr::IntElementIterator::operator*() const {
670   return readBits(getData(),
671                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
672                   bitWidth);
673 }
674 
675 //===----------------------------------------------------------------------===//
676 // ComplexIntElementIterator
677 
678 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
679     DenseElementsAttr attr, size_t dataIndex)
680     : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
681                                       std::complex<APInt>, std::complex<APInt>,
682                                       std::complex<APInt>>(
683           attr.getRawData().data(), attr.isSplat(), dataIndex) {
684   auto complexType = llvm::cast<ComplexType>(attr.getElementType());
685   bitWidth = getDenseElementBitWidth(complexType.getElementType());
686 }
687 
688 std::complex<APInt>
689 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
690   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
691   size_t offset = getDataIndex() * storageWidth * 2;
692   return {readBits(getData(), offset, bitWidth),
693           readBits(getData(), offset + storageWidth, bitWidth)};
694 }
695 
696 //===----------------------------------------------------------------------===//
697 // DenseArrayAttr
698 //===----------------------------------------------------------------------===//
699 
700 LogicalResult
701 DenseArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError,
702                        Type elementType, int64_t size, ArrayRef<char> rawData) {
703   if (!elementType.isIntOrIndexOrFloat())
704     return emitError() << "expected integer or floating point element type";
705   int64_t dataSize = rawData.size();
706   int64_t elementSize =
707       llvm::divideCeil(elementType.getIntOrFloatBitWidth(), CHAR_BIT);
708   if (size * elementSize != dataSize) {
709     return emitError() << "expected data size (" << size << " elements, "
710                        << elementSize
711                        << " bytes each) does not match: " << dataSize
712                        << " bytes";
713   }
714   return success();
715 }
716 
717 namespace {
718 /// Instantiations of this class provide utilities for interacting with native
719 /// data types in the context of DenseArrayAttr.
720 template <size_t width,
721           IntegerType::SignednessSemantics signedness = IntegerType::Signless>
722 struct DenseArrayAttrIntUtil {
723   static bool checkElementType(Type eltType) {
724     auto type = llvm::dyn_cast<IntegerType>(eltType);
725     if (!type || type.getWidth() != width)
726       return false;
727     return type.getSignedness() == signedness;
728   }
729 
730   static Type getElementType(MLIRContext *ctx) {
731     return IntegerType::get(ctx, width, signedness);
732   }
733 
734   template <typename T>
735   static void printElement(raw_ostream &os, T value) {
736     os << value;
737   }
738 
739   template <typename T>
740   static ParseResult parseElement(AsmParser &parser, T &value) {
741     return parser.parseInteger(value);
742   }
743 };
744 template <typename T>
745 struct DenseArrayAttrUtil;
746 
747 /// Specialization for boolean elements to print 'true' and 'false' literals for
748 /// elements.
749 template <>
750 struct DenseArrayAttrUtil<bool> : public DenseArrayAttrIntUtil<1> {
751   static void printElement(raw_ostream &os, bool value) {
752     os << (value ? "true" : "false");
753   }
754 };
755 
756 /// Specialization for 8-bit integers to ensure values are printed as integers
757 /// and not characters.
758 template <>
759 struct DenseArrayAttrUtil<int8_t> : public DenseArrayAttrIntUtil<8> {
760   static void printElement(raw_ostream &os, int8_t value) {
761     os << static_cast<int>(value);
762   }
763 };
764 template <>
765 struct DenseArrayAttrUtil<int16_t> : public DenseArrayAttrIntUtil<16> {};
766 template <>
767 struct DenseArrayAttrUtil<int32_t> : public DenseArrayAttrIntUtil<32> {};
768 template <>
769 struct DenseArrayAttrUtil<int64_t> : public DenseArrayAttrIntUtil<64> {};
770 
771 /// Specialization for 32-bit floats.
772 template <>
773 struct DenseArrayAttrUtil<float> {
774   static bool checkElementType(Type eltType) { return eltType.isF32(); }
775   static Type getElementType(MLIRContext *ctx) { return Float32Type::get(ctx); }
776   static void printElement(raw_ostream &os, float value) { os << value; }
777 
778   /// Parse a double and cast it to a float.
779   static ParseResult parseElement(AsmParser &parser, float &value) {
780     double doubleVal;
781     if (parser.parseFloat(doubleVal))
782       return failure();
783     value = doubleVal;
784     return success();
785   }
786 };
787 
788 /// Specialization for 64-bit floats.
789 template <>
790 struct DenseArrayAttrUtil<double> {
791   static bool checkElementType(Type eltType) { return eltType.isF64(); }
792   static Type getElementType(MLIRContext *ctx) { return Float64Type::get(ctx); }
793   static void printElement(raw_ostream &os, float value) { os << value; }
794   static ParseResult parseElement(AsmParser &parser, double &value) {
795     return parser.parseFloat(value);
796   }
797 };
798 } // namespace
799 
800 template <typename T>
801 void DenseArrayAttrImpl<T>::print(AsmPrinter &printer) const {
802   print(printer.getStream());
803 }
804 
805 template <typename T>
806 void DenseArrayAttrImpl<T>::printWithoutBraces(raw_ostream &os) const {
807   llvm::interleaveComma(asArrayRef(), os, [&](T value) {
808     DenseArrayAttrUtil<T>::printElement(os, value);
809   });
810 }
811 
812 template <typename T>
813 void DenseArrayAttrImpl<T>::print(raw_ostream &os) const {
814   os << "[";
815   printWithoutBraces(os);
816   os << "]";
817 }
818 
819 /// Parse a DenseArrayAttr without the braces: `1, 2, 3`
820 template <typename T>
821 Attribute DenseArrayAttrImpl<T>::parseWithoutBraces(AsmParser &parser,
822                                                     Type odsType) {
823   SmallVector<T> data;
824   if (failed(parser.parseCommaSeparatedList([&]() {
825         T value;
826         if (DenseArrayAttrUtil<T>::parseElement(parser, value))
827           return failure();
828         data.push_back(value);
829         return success();
830       })))
831     return {};
832   return get(parser.getContext(), data);
833 }
834 
835 /// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`
836 template <typename T>
837 Attribute DenseArrayAttrImpl<T>::parse(AsmParser &parser, Type odsType) {
838   if (parser.parseLSquare())
839     return {};
840   // Handle empty list case.
841   if (succeeded(parser.parseOptionalRSquare()))
842     return get(parser.getContext(), {});
843   Attribute result = parseWithoutBraces(parser, odsType);
844   if (parser.parseRSquare())
845     return {};
846   return result;
847 }
848 
849 /// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
850 template <typename T>
851 DenseArrayAttrImpl<T>::operator ArrayRef<T>() const {
852   ArrayRef<char> raw = getRawData();
853   assert((raw.size() % sizeof(T)) == 0);
854   return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
855                      raw.size() / sizeof(T));
856 }
857 
858 /// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
859 template <typename T>
860 DenseArrayAttrImpl<T> DenseArrayAttrImpl<T>::get(MLIRContext *context,
861                                                  ArrayRef<T> content) {
862   Type elementType = DenseArrayAttrUtil<T>::getElementType(context);
863   auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
864                                  content.size() * sizeof(T));
865   return llvm::cast<DenseArrayAttrImpl<T>>(
866       Base::get(context, elementType, content.size(), rawArray));
867 }
868 
869 template <typename T>
870 bool DenseArrayAttrImpl<T>::classof(Attribute attr) {
871   if (auto denseArray = llvm::dyn_cast<DenseArrayAttr>(attr))
872     return DenseArrayAttrUtil<T>::checkElementType(denseArray.getElementType());
873   return false;
874 }
875 
876 namespace mlir {
877 namespace detail {
878 // Explicit instantiation for all the supported DenseArrayAttr.
879 template class DenseArrayAttrImpl<bool>;
880 template class DenseArrayAttrImpl<int8_t>;
881 template class DenseArrayAttrImpl<int16_t>;
882 template class DenseArrayAttrImpl<int32_t>;
883 template class DenseArrayAttrImpl<int64_t>;
884 template class DenseArrayAttrImpl<float>;
885 template class DenseArrayAttrImpl<double>;
886 } // namespace detail
887 } // namespace mlir
888 
889 //===----------------------------------------------------------------------===//
890 // DenseElementsAttr
891 //===----------------------------------------------------------------------===//
892 
893 /// Method for support type inquiry through isa, cast and dyn_cast.
894 bool DenseElementsAttr::classof(Attribute attr) {
895   return llvm::isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(attr);
896 }
897 
898 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
899                                          ArrayRef<Attribute> values) {
900   assert(hasSameElementsOrSplat(type, values));
901 
902   Type eltType = type.getElementType();
903 
904   // Take care complex type case first.
905   if (auto complexType = llvm::dyn_cast<ComplexType>(eltType)) {
906     if (complexType.getElementType().isIntOrIndex()) {
907       SmallVector<std::complex<APInt>> complexValues;
908       complexValues.reserve(values.size());
909       for (Attribute attr : values) {
910         assert(llvm::isa<ArrayAttr>(attr) && "expected ArrayAttr for complex");
911         auto arrayAttr = llvm::cast<ArrayAttr>(attr);
912         assert(arrayAttr.size() == 2 && "expected 2 element for complex");
913         auto attr0 = arrayAttr[0];
914         auto attr1 = arrayAttr[1];
915         complexValues.push_back(
916             std::complex<APInt>(llvm::cast<IntegerAttr>(attr0).getValue(),
917                                 llvm::cast<IntegerAttr>(attr1).getValue()));
918       }
919       return DenseElementsAttr::get(type, complexValues);
920     }
921     // Must be float.
922     SmallVector<std::complex<APFloat>> complexValues;
923     complexValues.reserve(values.size());
924     for (Attribute attr : values) {
925       assert(llvm::isa<ArrayAttr>(attr) && "expected ArrayAttr for complex");
926       auto arrayAttr = llvm::cast<ArrayAttr>(attr);
927       assert(arrayAttr.size() == 2 && "expected 2 element for complex");
928       auto attr0 = arrayAttr[0];
929       auto attr1 = arrayAttr[1];
930       complexValues.push_back(
931           std::complex<APFloat>(llvm::cast<FloatAttr>(attr0).getValue(),
932                                 llvm::cast<FloatAttr>(attr1).getValue()));
933     }
934     return DenseElementsAttr::get(type, complexValues);
935   }
936 
937   // If the element type is not based on int/float/index, assume it is a string
938   // type.
939   if (!eltType.isIntOrIndexOrFloat()) {
940     SmallVector<StringRef, 8> stringValues;
941     stringValues.reserve(values.size());
942     for (Attribute attr : values) {
943       assert(llvm::isa<StringAttr>(attr) &&
944              "expected string value for non integer/index/float element");
945       stringValues.push_back(llvm::cast<StringAttr>(attr).getValue());
946     }
947     return get(type, stringValues);
948   }
949 
950   // Otherwise, get the raw storage width to use for the allocation.
951   size_t bitWidth = getDenseElementBitWidth(eltType);
952   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
953 
954   // Compress the attribute values into a character buffer.
955   SmallVector<char, 8> data(
956       llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT));
957   APInt intVal;
958   for (unsigned i = 0, e = values.size(); i < e; ++i) {
959     if (auto floatAttr = llvm::dyn_cast<FloatAttr>(values[i])) {
960       assert(floatAttr.getType() == eltType &&
961              "expected float attribute type to equal element type");
962       intVal = floatAttr.getValue().bitcastToAPInt();
963     } else {
964       auto intAttr = llvm::cast<IntegerAttr>(values[i]);
965       assert(intAttr.getType() == eltType &&
966              "expected integer attribute type to equal element type");
967       intVal = intAttr.getValue();
968     }
969 
970     assert(intVal.getBitWidth() == bitWidth &&
971            "expected value to have same bitwidth as element type");
972     writeBits(data.data(), i * storageBitWidth, intVal);
973   }
974 
975   // Handle the special encoding of splat of bool.
976   if (values.size() == 1 && eltType.isInteger(1))
977     data[0] = data[0] ? -1 : 0;
978 
979   return DenseIntOrFPElementsAttr::getRaw(type, data);
980 }
981 
982 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
983                                          ArrayRef<bool> values) {
984   assert(hasSameElementsOrSplat(type, values));
985   assert(type.getElementType().isInteger(1));
986 
987   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
988 
989   if (!values.empty()) {
990     bool isSplat = true;
991     bool firstValue = values[0];
992     for (int i = 0, e = values.size(); i != e; ++i) {
993       isSplat &= values[i] == firstValue;
994       setBit(buff.data(), i, values[i]);
995     }
996 
997     // Splat of bool is encoded as a byte with all-ones in it.
998     if (isSplat) {
999       buff.resize(1);
1000       buff[0] = values[0] ? -1 : 0;
1001     }
1002   }
1003 
1004   return DenseIntOrFPElementsAttr::getRaw(type, buff);
1005 }
1006 
1007 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1008                                          ArrayRef<StringRef> values) {
1009   assert(!type.getElementType().isIntOrFloat());
1010   return DenseStringElementsAttr::get(type, values);
1011 }
1012 
1013 /// Constructs a dense integer elements attribute from an array of APInt
1014 /// values. Each APInt value is expected to have the same bitwidth as the
1015 /// element type of 'type'.
1016 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1017                                          ArrayRef<APInt> values) {
1018   assert(type.getElementType().isIntOrIndex());
1019   assert(hasSameElementsOrSplat(type, values));
1020   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
1021   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
1022 }
1023 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1024                                          ArrayRef<std::complex<APInt>> values) {
1025   ComplexType complex = llvm::cast<ComplexType>(type.getElementType());
1026   assert(llvm::isa<IntegerType>(complex.getElementType()));
1027   assert(hasSameElementsOrSplat(type, values));
1028   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
1029   ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
1030                           values.size() * 2);
1031   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals);
1032 }
1033 
1034 // Constructs a dense float elements attribute from an array of APFloat
1035 // values. Each APFloat value is expected to have the same bitwidth as the
1036 // element type of 'type'.
1037 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1038                                          ArrayRef<APFloat> values) {
1039   assert(llvm::isa<FloatType>(type.getElementType()));
1040   assert(hasSameElementsOrSplat(type, values));
1041   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
1042   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
1043 }
1044 DenseElementsAttr
1045 DenseElementsAttr::get(ShapedType type,
1046                        ArrayRef<std::complex<APFloat>> values) {
1047   ComplexType complex = llvm::cast<ComplexType>(type.getElementType());
1048   assert(llvm::isa<FloatType>(complex.getElementType()));
1049   assert(hasSameElementsOrSplat(type, values));
1050   ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
1051                            values.size() * 2);
1052   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
1053   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals);
1054 }
1055 
1056 /// Construct a dense elements attribute from a raw buffer representing the
1057 /// data for this attribute. Users should generally not use this methods as
1058 /// the expected buffer format may not be a form the user expects.
1059 DenseElementsAttr
1060 DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef<char> rawBuffer) {
1061   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer);
1062 }
1063 
1064 /// Returns true if the given buffer is a valid raw buffer for the given type.
1065 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
1066                                          ArrayRef<char> rawBuffer,
1067                                          bool &detectedSplat) {
1068   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
1069   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
1070   int64_t numElements = type.getNumElements();
1071 
1072   // The initializer is always a splat if the result type has a single element.
1073   detectedSplat = numElements == 1;
1074 
1075   // Storage width of 1 is special as it is packed by the bit.
1076   if (storageWidth == 1) {
1077     // Check for a splat, or a buffer equal to the number of elements which
1078     // consists of either all 0's or all 1's.
1079     if (rawBuffer.size() == 1) {
1080       auto rawByte = static_cast<uint8_t>(rawBuffer[0]);
1081       if (rawByte == 0 || rawByte == 0xff) {
1082         detectedSplat = true;
1083         return true;
1084       }
1085     }
1086 
1087     // This is a valid non-splat buffer if it has the right size.
1088     return rawBufferWidth == llvm::alignTo<8>(numElements);
1089   }
1090 
1091   // All other types are 8-bit aligned, so we can just check the buffer width
1092   // to know if only a single initializer element was passed in.
1093   if (rawBufferWidth == storageWidth) {
1094     detectedSplat = true;
1095     return true;
1096   }
1097 
1098   // The raw buffer is valid if it has the right size.
1099   return rawBufferWidth == storageWidth * numElements;
1100 }
1101 
1102 /// Check the information for a C++ data type, check if this type is valid for
1103 /// the current attribute. This method is used to verify specific type
1104 /// invariants that the templatized 'getValues' method cannot.
1105 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
1106                               bool isSigned) {
1107   // Make sure that the data element size is the same as the type element width.
1108   auto denseEltBitWidth = getDenseElementBitWidth(type);
1109   auto dataSize = static_cast<size_t>(dataEltSize * CHAR_BIT);
1110   if (denseEltBitWidth != dataSize) {
1111     LLVM_DEBUG(llvm::dbgs() << "expected dense element bit width "
1112                             << denseEltBitWidth << " to match data size "
1113                             << dataSize << " for type " << type << "\n");
1114     return false;
1115   }
1116 
1117   // Check that the element type is either float or integer or index.
1118   if (!isInt) {
1119     bool valid = llvm::isa<FloatType>(type);
1120     if (!valid)
1121       LLVM_DEBUG(llvm::dbgs()
1122                  << "expected float type when isInt is false, but found "
1123                  << type << "\n");
1124     return valid;
1125   }
1126   if (type.isIndex())
1127     return true;
1128 
1129   auto intType = llvm::dyn_cast<IntegerType>(type);
1130   if (!intType) {
1131     LLVM_DEBUG(llvm::dbgs()
1132                << "expected integer type when isInt is true, but found " << type
1133                << "\n");
1134     return false;
1135   }
1136 
1137   // Make sure signedness semantics is consistent.
1138   if (intType.isSignless())
1139     return true;
1140 
1141   bool valid = intType.isSigned() == isSigned;
1142   if (!valid)
1143     LLVM_DEBUG(llvm::dbgs() << "expected signedness " << isSigned
1144                             << " to match type " << type << "\n");
1145   return valid;
1146 }
1147 
1148 /// Defaults down the subclass implementation.
1149 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
1150                                                    ArrayRef<char> data,
1151                                                    int64_t dataEltSize,
1152                                                    bool isInt, bool isSigned) {
1153   return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
1154                                                  isSigned);
1155 }
1156 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
1157                                                       ArrayRef<char> data,
1158                                                       int64_t dataEltSize,
1159                                                       bool isInt,
1160                                                       bool isSigned) {
1161   return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
1162                                                     isInt, isSigned);
1163 }
1164 
1165 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
1166                                           bool isSigned) const {
1167   return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned);
1168 }
1169 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
1170                                        bool isSigned) const {
1171   return ::isValidIntOrFloat(
1172       llvm::cast<ComplexType>(getElementType()).getElementType(),
1173       dataEltSize / 2, isInt, isSigned);
1174 }
1175 
1176 /// Returns true if this attribute corresponds to a splat, i.e. if all element
1177 /// values are the same.
1178 bool DenseElementsAttr::isSplat() const {
1179   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
1180 }
1181 
1182 /// Return if the given complex type has an integer element type.
1183 static bool isComplexOfIntType(Type type) {
1184   return llvm::isa<IntegerType>(llvm::cast<ComplexType>(type).getElementType());
1185 }
1186 
1187 auto DenseElementsAttr::tryGetComplexIntValues() const
1188     -> FailureOr<iterator_range_impl<ComplexIntElementIterator>> {
1189   if (!isComplexOfIntType(getElementType()))
1190     return failure();
1191   return iterator_range_impl<ComplexIntElementIterator>(
1192       getType(), ComplexIntElementIterator(*this, 0),
1193       ComplexIntElementIterator(*this, getNumElements()));
1194 }
1195 
1196 auto DenseElementsAttr::tryGetFloatValues() const
1197     -> FailureOr<iterator_range_impl<FloatElementIterator>> {
1198   auto eltTy = llvm::dyn_cast<FloatType>(getElementType());
1199   if (!eltTy)
1200     return failure();
1201   const auto &elementSemantics = eltTy.getFloatSemantics();
1202   return iterator_range_impl<FloatElementIterator>(
1203       getType(), FloatElementIterator(elementSemantics, raw_int_begin()),
1204       FloatElementIterator(elementSemantics, raw_int_end()));
1205 }
1206 
1207 auto DenseElementsAttr::tryGetComplexFloatValues() const
1208     -> FailureOr<iterator_range_impl<ComplexFloatElementIterator>> {
1209   auto complexTy = llvm::dyn_cast<ComplexType>(getElementType());
1210   if (!complexTy)
1211     return failure();
1212   auto eltTy = llvm::dyn_cast<FloatType>(complexTy.getElementType());
1213   if (!eltTy)
1214     return failure();
1215   const auto &semantics = eltTy.getFloatSemantics();
1216   return iterator_range_impl<ComplexFloatElementIterator>(
1217       getType(), {semantics, {*this, 0}},
1218       {semantics, {*this, static_cast<size_t>(getNumElements())}});
1219 }
1220 
1221 /// Return the raw storage data held by this attribute.
1222 ArrayRef<char> DenseElementsAttr::getRawData() const {
1223   return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
1224 }
1225 
1226 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
1227   return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
1228 }
1229 
1230 /// Return a new DenseElementsAttr that has the same data as the current
1231 /// attribute, but has been reshaped to 'newType'. The new type must have the
1232 /// same total number of elements as well as element type.
1233 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
1234   ShapedType curType = getType();
1235   if (curType == newType)
1236     return *this;
1237 
1238   assert(newType.getElementType() == curType.getElementType() &&
1239          "expected the same element type");
1240   assert(newType.getNumElements() == curType.getNumElements() &&
1241          "expected the same number of elements");
1242   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
1243 }
1244 
1245 DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
1246   assert(isSplat() && "expected a splat type");
1247 
1248   ShapedType curType = getType();
1249   if (curType == newType)
1250     return *this;
1251 
1252   assert(newType.getElementType() == curType.getElementType() &&
1253          "expected the same element type");
1254   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
1255 }
1256 
1257 /// Return a new DenseElementsAttr that has the same data as the current
1258 /// attribute, but has bitcast elements such that it is now 'newType'. The new
1259 /// type must have the same shape and element types of the same bitwidth as the
1260 /// current type.
1261 DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
1262   ShapedType curType = getType();
1263   Type curElType = curType.getElementType();
1264   if (curElType == newElType)
1265     return *this;
1266 
1267   assert(getDenseElementBitWidth(newElType) ==
1268              getDenseElementBitWidth(curElType) &&
1269          "expected element types with the same bitwidth");
1270   return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType),
1271                                           getRawData());
1272 }
1273 
1274 DenseElementsAttr
1275 DenseElementsAttr::mapValues(Type newElementType,
1276                              function_ref<APInt(const APInt &)> mapping) const {
1277   return llvm::cast<DenseIntElementsAttr>(*this).mapValues(newElementType,
1278                                                            mapping);
1279 }
1280 
1281 DenseElementsAttr DenseElementsAttr::mapValues(
1282     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1283   return llvm::cast<DenseFPElementsAttr>(*this).mapValues(newElementType,
1284                                                           mapping);
1285 }
1286 
1287 ShapedType DenseElementsAttr::getType() const {
1288   return static_cast<const DenseElementsAttributeStorage *>(impl)->type;
1289 }
1290 
1291 Type DenseElementsAttr::getElementType() const {
1292   return getType().getElementType();
1293 }
1294 
1295 int64_t DenseElementsAttr::getNumElements() const {
1296   return getType().getNumElements();
1297 }
1298 
1299 //===----------------------------------------------------------------------===//
1300 // DenseIntOrFPElementsAttr
1301 //===----------------------------------------------------------------------===//
1302 
1303 /// Utility method to write a range of APInt values to a buffer.
1304 template <typename APRangeT>
1305 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1306                                 APRangeT &&values) {
1307   size_t numValues = llvm::size(values);
1308   data.resize(llvm::divideCeil(storageWidth * numValues, CHAR_BIT));
1309   size_t offset = 0;
1310   for (auto it = values.begin(), e = values.end(); it != e;
1311        ++it, offset += storageWidth) {
1312     assert((*it).getBitWidth() <= storageWidth);
1313     writeBits(data.data(), offset, *it);
1314   }
1315 
1316   // Handle the special encoding of splat of a boolean.
1317   if (numValues == 1 && (*values.begin()).getBitWidth() == 1)
1318     data[0] = data[0] ? -1 : 0;
1319 }
1320 
1321 /// Constructs a dense elements attribute from an array of raw APFloat values.
1322 /// Each APFloat value is expected to have the same bitwidth as the element
1323 /// type of 'type'. 'type' must be a vector or tensor with static shape.
1324 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1325                                                    size_t storageWidth,
1326                                                    ArrayRef<APFloat> values) {
1327   std::vector<char> data;
1328   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1329   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1330   return DenseIntOrFPElementsAttr::getRaw(type, data);
1331 }
1332 
1333 /// Constructs a dense elements attribute from an array of raw APInt values.
1334 /// Each APInt value is expected to have the same bitwidth as the element type
1335 /// of 'type'.
1336 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1337                                                    size_t storageWidth,
1338                                                    ArrayRef<APInt> values) {
1339   std::vector<char> data;
1340   writeAPIntsToBuffer(storageWidth, data, values);
1341   return DenseIntOrFPElementsAttr::getRaw(type, data);
1342 }
1343 
1344 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1345                                                    ArrayRef<char> data) {
1346   assert(type.hasStaticShape() && "type must have static shape");
1347   bool isSplat = false;
1348   bool isValid = isValidRawBuffer(type, data, isSplat);
1349   assert(isValid);
1350   (void)isValid;
1351   return Base::get(type.getContext(), type, data, isSplat);
1352 }
1353 
1354 /// Overload of the raw 'get' method that asserts that the given type is of
1355 /// complex type. This method is used to verify type invariants that the
1356 /// templatized 'get' method cannot.
1357 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1358                                                           ArrayRef<char> data,
1359                                                           int64_t dataEltSize,
1360                                                           bool isInt,
1361                                                           bool isSigned) {
1362   assert(::isValidIntOrFloat(
1363              llvm::cast<ComplexType>(type.getElementType()).getElementType(),
1364              dataEltSize / 2, isInt, isSigned) &&
1365          "Try re-running with -debug-only=builtinattributes");
1366 
1367   int64_t numElements = data.size() / dataEltSize;
1368   (void)numElements;
1369   assert(numElements == 1 || numElements == type.getNumElements());
1370   return getRaw(type, data);
1371 }
1372 
1373 /// Overload of the 'getRaw' method that asserts that the given type is of
1374 /// integer type. This method is used to verify type invariants that the
1375 /// templatized 'get' method cannot.
1376 DenseElementsAttr
1377 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1378                                            int64_t dataEltSize, bool isInt,
1379                                            bool isSigned) {
1380   assert(::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt,
1381                              isSigned) &&
1382          "Try re-running with -debug-only=builtinattributes");
1383 
1384   int64_t numElements = data.size() / dataEltSize;
1385   assert(numElements == 1 || numElements == type.getNumElements());
1386   (void)numElements;
1387   return getRaw(type, data);
1388 }
1389 
1390 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1391     const char *inRawData, char *outRawData, size_t elementBitWidth,
1392     size_t numElements) {
1393   using llvm::support::ulittle16_t;
1394   using llvm::support::ulittle32_t;
1395   using llvm::support::ulittle64_t;
1396 
1397   assert(llvm::endianness::native == llvm::endianness::big);
1398   // NOLINT to avoid warning message about replacing by static_assert()
1399 
1400   // Following std::copy_n always converts endianness on BE machine.
1401   switch (elementBitWidth) {
1402   case 16: {
1403     const ulittle16_t *inRawDataPos =
1404         reinterpret_cast<const ulittle16_t *>(inRawData);
1405     uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1406     std::copy_n(inRawDataPos, numElements, outDataPos);
1407     break;
1408   }
1409   case 32: {
1410     const ulittle32_t *inRawDataPos =
1411         reinterpret_cast<const ulittle32_t *>(inRawData);
1412     uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1413     std::copy_n(inRawDataPos, numElements, outDataPos);
1414     break;
1415   }
1416   case 64: {
1417     const ulittle64_t *inRawDataPos =
1418         reinterpret_cast<const ulittle64_t *>(inRawData);
1419     uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1420     std::copy_n(inRawDataPos, numElements, outDataPos);
1421     break;
1422   }
1423   default: {
1424     size_t nBytes = elementBitWidth / CHAR_BIT;
1425     for (size_t i = 0; i < nBytes; i++)
1426       std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1427     break;
1428   }
1429   }
1430 }
1431 
1432 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1433     ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1434     ShapedType type) {
1435   size_t numElements = type.getNumElements();
1436   Type elementType = type.getElementType();
1437   if (ComplexType complexTy = llvm::dyn_cast<ComplexType>(elementType)) {
1438     elementType = complexTy.getElementType();
1439     numElements = numElements * 2;
1440   }
1441   size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1442   assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1443          inRawData.size() <= outRawData.size());
1444   if (elementBitWidth <= CHAR_BIT)
1445     std::memcpy(outRawData.begin(), inRawData.begin(), inRawData.size());
1446   else
1447     convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1448                                     elementBitWidth, numElements);
1449 }
1450 
1451 //===----------------------------------------------------------------------===//
1452 // DenseFPElementsAttr
1453 //===----------------------------------------------------------------------===//
1454 
1455 template <typename Fn, typename Attr>
1456 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1457                                 Type newElementType,
1458                                 llvm::SmallVectorImpl<char> &data) {
1459   size_t bitWidth = getDenseElementBitWidth(newElementType);
1460   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1461 
1462   ShapedType newArrayType = inType.cloneWith(inType.getShape(), newElementType);
1463 
1464   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1465   data.resize(llvm::divideCeil(storageBitWidth * numRawElements, CHAR_BIT));
1466 
1467   // Functor used to process a single element value of the attribute.
1468   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1469     auto newInt = mapping(value);
1470     assert(newInt.getBitWidth() == bitWidth);
1471     writeBits(data.data(), index * storageBitWidth, newInt);
1472   };
1473 
1474   // Check for the splat case.
1475   if (attr.isSplat()) {
1476     if (bitWidth == 1) {
1477       // Handle the special encoding of splat of bool.
1478       data[0] = mapping(*attr.begin()).isZero() ? 0 : -1;
1479     } else {
1480       processElt(*attr.begin(), /*index=*/0);
1481     }
1482     return newArrayType;
1483   }
1484 
1485   // Otherwise, process all of the element values.
1486   uint64_t elementIdx = 0;
1487   for (auto value : attr)
1488     processElt(value, elementIdx++);
1489   return newArrayType;
1490 }
1491 
1492 DenseElementsAttr DenseFPElementsAttr::mapValues(
1493     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1494   llvm::SmallVector<char, 8> elementData;
1495   auto newArrayType =
1496       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1497 
1498   return getRaw(newArrayType, elementData);
1499 }
1500 
1501 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1502 bool DenseFPElementsAttr::classof(Attribute attr) {
1503   if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(attr))
1504     return llvm::isa<FloatType>(denseAttr.getType().getElementType());
1505   return false;
1506 }
1507 
1508 //===----------------------------------------------------------------------===//
1509 // DenseIntElementsAttr
1510 //===----------------------------------------------------------------------===//
1511 
1512 DenseElementsAttr DenseIntElementsAttr::mapValues(
1513     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1514   llvm::SmallVector<char, 8> elementData;
1515   auto newArrayType =
1516       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1517   return getRaw(newArrayType, elementData);
1518 }
1519 
1520 /// Method for supporting type inquiry through isa, cast and dyn_cast.
1521 bool DenseIntElementsAttr::classof(Attribute attr) {
1522   if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(attr))
1523     return denseAttr.getType().getElementType().isIntOrIndex();
1524   return false;
1525 }
1526 
1527 //===----------------------------------------------------------------------===//
1528 // DenseResourceElementsAttr
1529 //===----------------------------------------------------------------------===//
1530 
1531 DenseResourceElementsAttr
1532 DenseResourceElementsAttr::get(ShapedType type,
1533                                DenseResourceElementsHandle handle) {
1534   return Base::get(type.getContext(), type, handle);
1535 }
1536 
1537 DenseResourceElementsAttr DenseResourceElementsAttr::get(ShapedType type,
1538                                                          StringRef blobName,
1539                                                          AsmResourceBlob blob) {
1540   // Extract the builtin dialect resource manager from context and construct a
1541   // handle by inserting a new resource using the provided blob.
1542   auto &manager =
1543       DenseResourceElementsHandle::getManagerInterface(type.getContext());
1544   return get(type, manager.insert(blobName, std::move(blob)));
1545 }
1546 
1547 ArrayRef<char> DenseResourceElementsAttr::getData() {
1548   if (AsmResourceBlob *blob = this->getRawHandle().getBlob())
1549     return blob->getDataAs<char>();
1550   return {};
1551 }
1552 
1553 //===----------------------------------------------------------------------===//
1554 // DenseResourceElementsAttrBase
1555 
1556 namespace {
1557 /// Instantiations of this class provide utilities for interacting with native
1558 /// data types in the context of DenseResourceElementsAttr.
1559 template <typename T>
1560 struct DenseResourceAttrUtil;
1561 template <size_t width, bool isSigned>
1562 struct DenseResourceElementsAttrIntUtil {
1563   static bool checkElementType(Type eltType) {
1564     IntegerType type = llvm::dyn_cast<IntegerType>(eltType);
1565     if (!type || type.getWidth() != width)
1566       return false;
1567     return isSigned ? !type.isUnsigned() : !type.isSigned();
1568   }
1569 };
1570 template <>
1571 struct DenseResourceAttrUtil<bool> {
1572   static bool checkElementType(Type eltType) {
1573     return eltType.isSignlessInteger(1);
1574   }
1575 };
1576 template <>
1577 struct DenseResourceAttrUtil<int8_t>
1578     : public DenseResourceElementsAttrIntUtil<8, true> {};
1579 template <>
1580 struct DenseResourceAttrUtil<uint8_t>
1581     : public DenseResourceElementsAttrIntUtil<8, false> {};
1582 template <>
1583 struct DenseResourceAttrUtil<int16_t>
1584     : public DenseResourceElementsAttrIntUtil<16, true> {};
1585 template <>
1586 struct DenseResourceAttrUtil<uint16_t>
1587     : public DenseResourceElementsAttrIntUtil<16, false> {};
1588 template <>
1589 struct DenseResourceAttrUtil<int32_t>
1590     : public DenseResourceElementsAttrIntUtil<32, true> {};
1591 template <>
1592 struct DenseResourceAttrUtil<uint32_t>
1593     : public DenseResourceElementsAttrIntUtil<32, false> {};
1594 template <>
1595 struct DenseResourceAttrUtil<int64_t>
1596     : public DenseResourceElementsAttrIntUtil<64, true> {};
1597 template <>
1598 struct DenseResourceAttrUtil<uint64_t>
1599     : public DenseResourceElementsAttrIntUtil<64, false> {};
1600 template <>
1601 struct DenseResourceAttrUtil<float> {
1602   static bool checkElementType(Type eltType) { return eltType.isF32(); }
1603 };
1604 template <>
1605 struct DenseResourceAttrUtil<double> {
1606   static bool checkElementType(Type eltType) { return eltType.isF64(); }
1607 };
1608 } // namespace
1609 
1610 template <typename T>
1611 DenseResourceElementsAttrBase<T>
1612 DenseResourceElementsAttrBase<T>::get(ShapedType type, StringRef blobName,
1613                                       AsmResourceBlob blob) {
1614   // Check that the blob is in the form we were expecting.
1615   assert(blob.getDataAlignment() == alignof(T) &&
1616          "alignment mismatch between expected alignment and blob alignment");
1617   assert(((blob.getData().size() % sizeof(T)) == 0) &&
1618          "size mismatch between expected element width and blob size");
1619   assert(DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) &&
1620          "invalid shape element type for provided type `T`");
1621   return llvm::cast<DenseResourceElementsAttrBase<T>>(
1622       DenseResourceElementsAttr::get(type, blobName, std::move(blob)));
1623 }
1624 
1625 template <typename T>
1626 std::optional<ArrayRef<T>>
1627 DenseResourceElementsAttrBase<T>::tryGetAsArrayRef() const {
1628   if (AsmResourceBlob *blob = this->getRawHandle().getBlob())
1629     return blob->template getDataAs<T>();
1630   return std::nullopt;
1631 }
1632 
1633 template <typename T>
1634 bool DenseResourceElementsAttrBase<T>::classof(Attribute attr) {
1635   auto resourceAttr = llvm::dyn_cast<DenseResourceElementsAttr>(attr);
1636   return resourceAttr && DenseResourceAttrUtil<T>::checkElementType(
1637                              resourceAttr.getElementType());
1638 }
1639 
1640 namespace mlir {
1641 namespace detail {
1642 // Explicit instantiation for all the supported DenseResourceElementsAttr.
1643 template class DenseResourceElementsAttrBase<bool>;
1644 template class DenseResourceElementsAttrBase<int8_t>;
1645 template class DenseResourceElementsAttrBase<int16_t>;
1646 template class DenseResourceElementsAttrBase<int32_t>;
1647 template class DenseResourceElementsAttrBase<int64_t>;
1648 template class DenseResourceElementsAttrBase<uint8_t>;
1649 template class DenseResourceElementsAttrBase<uint16_t>;
1650 template class DenseResourceElementsAttrBase<uint32_t>;
1651 template class DenseResourceElementsAttrBase<uint64_t>;
1652 template class DenseResourceElementsAttrBase<float>;
1653 template class DenseResourceElementsAttrBase<double>;
1654 } // namespace detail
1655 } // namespace mlir
1656 
1657 //===----------------------------------------------------------------------===//
1658 // SparseElementsAttr
1659 //===----------------------------------------------------------------------===//
1660 
1661 /// Get a zero APFloat for the given sparse attribute.
1662 APFloat SparseElementsAttr::getZeroAPFloat() const {
1663   auto eltType = llvm::cast<FloatType>(getElementType());
1664   return APFloat(eltType.getFloatSemantics());
1665 }
1666 
1667 /// Get a zero APInt for the given sparse attribute.
1668 APInt SparseElementsAttr::getZeroAPInt() const {
1669   auto eltType = llvm::cast<IntegerType>(getElementType());
1670   return APInt::getZero(eltType.getWidth());
1671 }
1672 
1673 /// Get a zero attribute for the given attribute type.
1674 Attribute SparseElementsAttr::getZeroAttr() const {
1675   auto eltType = getElementType();
1676 
1677   // Handle floating point elements.
1678   if (llvm::isa<FloatType>(eltType))
1679     return FloatAttr::get(eltType, 0);
1680 
1681   // Handle complex elements.
1682   if (auto complexTy = llvm::dyn_cast<ComplexType>(eltType)) {
1683     auto eltType = complexTy.getElementType();
1684     Attribute zero;
1685     if (llvm::isa<FloatType>(eltType))
1686       zero = FloatAttr::get(eltType, 0);
1687     else // must be integer
1688       zero = IntegerAttr::get(eltType, 0);
1689     return ArrayAttr::get(complexTy.getContext(),
1690                           ArrayRef<Attribute>{zero, zero});
1691   }
1692 
1693   // Handle string type.
1694   if (llvm::isa<DenseStringElementsAttr>(getValues()))
1695     return StringAttr::get("", eltType);
1696 
1697   // Otherwise, this is an integer.
1698   return IntegerAttr::get(eltType, 0);
1699 }
1700 
1701 /// Flatten, and return, all of the sparse indices in this attribute in
1702 /// row-major order.
1703 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1704   std::vector<ptrdiff_t> flatSparseIndices;
1705 
1706   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1707   // as a 1-D index array.
1708   auto sparseIndices = getIndices();
1709   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1710   if (sparseIndices.isSplat()) {
1711     SmallVector<uint64_t, 8> indices(getType().getRank(),
1712                                      *sparseIndexValues.begin());
1713     flatSparseIndices.push_back(getFlattenedIndex(indices));
1714     return flatSparseIndices;
1715   }
1716 
1717   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1718   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1719   size_t rank = getType().getRank();
1720   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1721     flatSparseIndices.push_back(getFlattenedIndex(
1722         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1723   return flatSparseIndices;
1724 }
1725 
1726 LogicalResult
1727 SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1728                            ShapedType type, DenseIntElementsAttr sparseIndices,
1729                            DenseElementsAttr values) {
1730   ShapedType valuesType = values.getType();
1731   if (valuesType.getRank() != 1)
1732     return emitError() << "expected 1-d tensor for sparse element values";
1733 
1734   // Verify the indices and values shape.
1735   ShapedType indicesType = sparseIndices.getType();
1736   auto emitShapeError = [&]() {
1737     return emitError() << "expected shape ([" << type.getShape()
1738                        << "]); inferred shape of indices literal (["
1739                        << indicesType.getShape()
1740                        << "]); inferred shape of values literal (["
1741                        << valuesType.getShape() << "])";
1742   };
1743   // Verify indices shape.
1744   size_t rank = type.getRank(), indicesRank = indicesType.getRank();
1745   if (indicesRank == 2) {
1746     if (indicesType.getDimSize(1) != static_cast<int64_t>(rank))
1747       return emitShapeError();
1748   } else if (indicesRank != 1 || rank != 1) {
1749     return emitShapeError();
1750   }
1751   // Verify the values shape.
1752   int64_t numSparseIndices = indicesType.getDimSize(0);
1753   if (numSparseIndices != valuesType.getDimSize(0))
1754     return emitShapeError();
1755 
1756   // Verify that the sparse indices are within the value shape.
1757   auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) {
1758     return emitError()
1759            << "sparse index #" << indexNum
1760            << " is not contained within the value shape, with index=[" << index
1761            << "], and type=" << type;
1762   };
1763 
1764   // Handle the case where the index values are a splat.
1765   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1766   if (sparseIndices.isSplat()) {
1767     SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin());
1768     if (!ElementsAttr::isValidIndex(type, indices))
1769       return emitIndexError(0, indices);
1770     return success();
1771   }
1772 
1773   // Otherwise, reinterpret each index as an ArrayRef.
1774   for (size_t i = 0, e = numSparseIndices; i != e; ++i) {
1775     ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank),
1776                              rank);
1777     if (!ElementsAttr::isValidIndex(type, index))
1778       return emitIndexError(i, index);
1779   }
1780 
1781   return success();
1782 }
1783 
1784 //===----------------------------------------------------------------------===//
1785 // DistinctAttr
1786 //===----------------------------------------------------------------------===//
1787 
1788 DistinctAttr DistinctAttr::create(Attribute referencedAttr) {
1789   return Base::get(referencedAttr.getContext(), referencedAttr);
1790 }
1791 
1792 Attribute DistinctAttr::getReferencedAttr() const {
1793   return getImpl()->referencedAttr;
1794 }
1795 
1796 //===----------------------------------------------------------------------===//
1797 // Attribute Utilities
1798 //===----------------------------------------------------------------------===//
1799 
1800 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
1801                                            int64_t offset,
1802                                            MLIRContext *context) {
1803   AffineExpr expr;
1804   unsigned nSymbols = 0;
1805 
1806   // AffineExpr for offset.
1807   // Static case.
1808   if (!ShapedType::isDynamic(offset)) {
1809     auto cst = getAffineConstantExpr(offset, context);
1810     expr = cst;
1811   } else {
1812     // Dynamic case, new symbol for the offset.
1813     auto sym = getAffineSymbolExpr(nSymbols++, context);
1814     expr = sym;
1815   }
1816 
1817   // AffineExpr for strides.
1818   for (const auto &en : llvm::enumerate(strides)) {
1819     auto dim = en.index();
1820     auto stride = en.value();
1821     auto d = getAffineDimExpr(dim, context);
1822     AffineExpr mult;
1823     // Static case.
1824     if (!ShapedType::isDynamic(stride))
1825       mult = getAffineConstantExpr(stride, context);
1826     else
1827       // Dynamic case, new symbol for each new stride.
1828       mult = getAffineSymbolExpr(nSymbols++, context);
1829     expr = expr + d * mult;
1830   }
1831 
1832   return AffineMap::get(strides.size(), nSymbols, expr);
1833 }
1834