xref: /llvm-project/mlir/lib/IR/AttributeDetail.h (revision dec8055a1e71fe25d4b85416ede742e8fdfaf3f0)
1 //===- AttributeDetail.h - MLIR Attribute details Class ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This holds implementation details of Attribute.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef ATTRIBUTEDETAIL_H_
14 #define ATTRIBUTEDETAIL_H_
15 
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/AttributeSupport.h"
18 #include "mlir/IR/BuiltinAttributes.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/IntegerSet.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/Support/StorageUniquer.h"
23 #include "mlir/Support/ThreadLocalCache.h"
24 #include "llvm/ADT/APFloat.h"
25 #include "llvm/ADT/PointerIntPair.h"
26 #include "llvm/Support/TrailingObjects.h"
27 
28 namespace mlir {
29 namespace detail {
30 
31 //===----------------------------------------------------------------------===//
32 // Elements Attributes
33 //===----------------------------------------------------------------------===//
34 
35 /// Return the bit width which DenseElementsAttr should use for this type.
getDenseElementBitWidth(Type eltType)36 inline size_t getDenseElementBitWidth(Type eltType) {
37   // Align the width for complex to 8 to make storage and interpretation easier.
38   if (ComplexType comp = llvm::dyn_cast<ComplexType>(eltType))
39     return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2;
40   if (eltType.isIndex())
41     return IndexType::kInternalStorageBitWidth;
42   return eltType.getIntOrFloatBitWidth();
43 }
44 
45 /// An attribute representing a reference to a dense vector or tensor object.
46 struct DenseElementsAttributeStorage : public AttributeStorage {
47 public:
DenseElementsAttributeStorageDenseElementsAttributeStorage48   DenseElementsAttributeStorage(ShapedType type, bool isSplat)
49       : type(type), isSplat(isSplat) {}
50 
51   ShapedType type;
52   bool isSplat;
53 };
54 
55 /// An attribute representing a reference to a dense vector or tensor object.
56 struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
57   DenseIntOrFPElementsAttrStorage(ShapedType ty, ArrayRef<char> data,
58                                   bool isSplat = false)
DenseElementsAttributeStorageDenseIntOrFPElementsAttrStorage59       : DenseElementsAttributeStorage(ty, isSplat), data(data) {}
60 
61   struct KeyTy {
62     KeyTy(ShapedType type, ArrayRef<char> data, llvm::hash_code hashCode,
63           bool isSplat = false)
typeDenseIntOrFPElementsAttrStorage::KeyTy64         : type(type), data(data), hashCode(hashCode), isSplat(isSplat) {}
65 
66     /// The type of the dense elements.
67     ShapedType type;
68 
69     /// The raw buffer for the data storage.
70     ArrayRef<char> data;
71 
72     /// The computed hash code for the storage data.
73     llvm::hash_code hashCode;
74 
75     /// A boolean that indicates if this data is a splat or not.
76     bool isSplat;
77   };
78 
79   /// Compare this storage instance with the provided key.
80   bool operator==(const KeyTy &key) const {
81     return key.type == type && key.data == data;
82   }
83 
84   /// Construct a key from a shaped type, raw data buffer, and a flag that
85   /// signals if the data is already known to be a splat. Callers to this
86   /// function are expected to tag preknown splat values when possible, e.g. one
87   /// element shapes.
getKeyDenseIntOrFPElementsAttrStorage88   static KeyTy getKey(ShapedType ty, ArrayRef<char> data, bool isKnownSplat) {
89     // Handle an empty storage instance.
90     if (data.empty())
91       return KeyTy(ty, data, 0);
92 
93     // If the data is already known to be a splat, the key hash value is
94     // directly the data buffer.
95     bool isBoolData = ty.getElementType().isInteger(1);
96     if (isKnownSplat) {
97       if (isBoolData)
98         return getKeyForSplatBoolData(ty, data[0] != 0);
99       return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat);
100     }
101 
102     // Otherwise, we need to check if the data corresponds to a splat or not.
103 
104     // Handle the simple case of only one element.
105     size_t numElements = ty.getNumElements();
106     assert(numElements != 1 && "splat of 1 element should already be detected");
107 
108     // Handle boolean values directly as they are packed to 1-bit.
109     if (isBoolData)
110       return getKeyForBoolData(ty, data, numElements);
111 
112     size_t elementWidth = getDenseElementBitWidth(ty.getElementType());
113     // Non 1-bit dense elements are padded to 8-bits.
114     size_t storageSize = llvm::divideCeil(elementWidth, CHAR_BIT);
115     assert(((data.size() / storageSize) == numElements) &&
116            "data does not hold expected number of elements");
117 
118     // Create the initial hash value with just the first element.
119     auto firstElt = data.take_front(storageSize);
120     auto hashVal = llvm::hash_value(firstElt);
121 
122     // Check to see if this storage represents a splat. If it doesn't then
123     // combine the hash for the data starting with the first non splat element.
124     for (size_t i = storageSize, e = data.size(); i != e; i += storageSize)
125       if (memcmp(data.data(), &data[i], storageSize))
126         return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
127 
128     // Otherwise, this is a splat so just return the hash of the first element.
129     return KeyTy(ty, firstElt, hashVal, /*isSplat=*/true);
130   }
131 
132   /// Construct a key with a set of boolean data.
getKeyForBoolDataDenseIntOrFPElementsAttrStorage133   static KeyTy getKeyForBoolData(ShapedType ty, ArrayRef<char> data,
134                                  size_t numElements) {
135     ArrayRef<char> splatData = data;
136     bool splatValue = splatData.front() & 1;
137 
138     // Check the simple case where the data matches the known splat value.
139     if (splatData == ArrayRef<char>(splatValue ? kSplatTrue : kSplatFalse))
140       return getKeyForSplatBoolData(ty, splatValue);
141 
142     // Handle the case where the potential splat value is 1 and the number of
143     // elements is non 8-bit aligned.
144     size_t numOddElements = numElements % CHAR_BIT;
145     if (splatValue && numOddElements != 0) {
146       // Check that all bits are set in the last value.
147       char lastElt = splatData.back();
148       if (lastElt != llvm::maskTrailingOnes<unsigned char>(numOddElements))
149         return KeyTy(ty, data, llvm::hash_value(data));
150 
151       // If this is the only element, the data is known to be a splat.
152       if (splatData.size() == 1)
153         return getKeyForSplatBoolData(ty, splatValue);
154       splatData = splatData.drop_back();
155     }
156 
157     // Check that the data buffer corresponds to a splat of the proper mask.
158     char mask = splatValue ? ~0 : 0;
159     return llvm::all_of(splatData, [mask](char c) { return c == mask; })
160                ? getKeyForSplatBoolData(ty, splatValue)
161                : KeyTy(ty, data, llvm::hash_value(data));
162   }
163 
164   /// Return a key to use for a boolean splat of the given value.
getKeyForSplatBoolDataDenseIntOrFPElementsAttrStorage165   static KeyTy getKeyForSplatBoolData(ShapedType type, bool splatValue) {
166     const char &splatData = splatValue ? kSplatTrue : kSplatFalse;
167     return KeyTy(type, splatData, llvm::hash_value(splatData),
168                  /*isSplat=*/true);
169   }
170 
171   /// Hash the key for the storage.
hashKeyDenseIntOrFPElementsAttrStorage172   static llvm::hash_code hashKey(const KeyTy &key) {
173     return llvm::hash_combine(key.type, key.hashCode);
174   }
175 
176   /// Construct a new storage instance.
177   static DenseIntOrFPElementsAttrStorage *
constructDenseIntOrFPElementsAttrStorage178   construct(AttributeStorageAllocator &allocator, KeyTy key) {
179     // If the data buffer is non-empty, we copy it into the allocator with a
180     // 64-bit alignment.
181     ArrayRef<char> copy, data = key.data;
182     if (!data.empty()) {
183       char *rawData = reinterpret_cast<char *>(
184           allocator.allocate(data.size(), alignof(uint64_t)));
185       std::memcpy(rawData, data.data(), data.size());
186       copy = ArrayRef<char>(rawData, data.size());
187     }
188 
189     return new (allocator.allocate<DenseIntOrFPElementsAttrStorage>())
190         DenseIntOrFPElementsAttrStorage(key.type, copy, key.isSplat);
191   }
192 
193   ArrayRef<char> data;
194 
195   /// The values used to denote a boolean splat value.
196   // This is not using constexpr declaration due to compilation failure
197   // encountered with MSVC where it would inline these values, which makes it
198   // unsafe to refer by reference in KeyTy.
199   static const char kSplatTrue;
200   static const char kSplatFalse;
201 };
202 
203 /// An attribute representing a reference to a dense vector or tensor object
204 /// containing strings.
205 struct DenseStringElementsAttrStorage : public DenseElementsAttributeStorage {
206   DenseStringElementsAttrStorage(ShapedType ty, ArrayRef<StringRef> data,
207                                  bool isSplat = false)
DenseElementsAttributeStorageDenseStringElementsAttrStorage208       : DenseElementsAttributeStorage(ty, isSplat), data(data) {}
209 
210   struct KeyTy {
211     KeyTy(ShapedType type, ArrayRef<StringRef> data, llvm::hash_code hashCode,
212           bool isSplat = false)
typeDenseStringElementsAttrStorage::KeyTy213         : type(type), data(data), hashCode(hashCode), isSplat(isSplat) {}
214 
215     /// The type of the dense elements.
216     ShapedType type;
217 
218     /// The raw buffer for the data storage.
219     ArrayRef<StringRef> data;
220 
221     /// The computed hash code for the storage data.
222     llvm::hash_code hashCode;
223 
224     /// A boolean that indicates if this data is a splat or not.
225     bool isSplat;
226   };
227 
228   /// Compare this storage instance with the provided key.
229   bool operator==(const KeyTy &key) const {
230     if (key.type != type)
231       return false;
232 
233     // Otherwise, we can default to just checking the data. StringRefs compare
234     // by contents.
235     return key.data == data;
236   }
237 
238   /// Construct a key from a shaped type, StringRef data buffer, and a flag that
239   /// signals if the data is already known to be a splat. Callers to this
240   /// function are expected to tag preknown splat values when possible, e.g. one
241   /// element shapes.
getKeyDenseStringElementsAttrStorage242   static KeyTy getKey(ShapedType ty, ArrayRef<StringRef> data,
243                       bool isKnownSplat) {
244     // Handle an empty storage instance.
245     if (data.empty())
246       return KeyTy(ty, data, 0);
247 
248     // If the data is already known to be a splat, the key hash value is
249     // directly the data buffer.
250     if (isKnownSplat)
251       return KeyTy(ty, data, llvm::hash_value(data.front()), isKnownSplat);
252 
253     // Handle the simple case of only one element.
254     assert(ty.getNumElements() != 1 &&
255            "splat of 1 element should already be detected");
256 
257     // Create the initial hash value with just the first element.
258     const auto &firstElt = data.front();
259     auto hashVal = llvm::hash_value(firstElt);
260 
261     // Check to see if this storage represents a splat. If it doesn't then
262     // combine the hash for the data starting with the first non splat element.
263     for (size_t i = 1, e = data.size(); i != e; i++)
264       if (firstElt != data[i])
265         return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
266 
267     // Otherwise, this is a splat so just return the hash of the first element.
268     return KeyTy(ty, data.take_front(), hashVal, /*isSplat=*/true);
269   }
270 
271   /// Hash the key for the storage.
hashKeyDenseStringElementsAttrStorage272   static llvm::hash_code hashKey(const KeyTy &key) {
273     return llvm::hash_combine(key.type, key.hashCode);
274   }
275 
276   /// Construct a new storage instance.
277   static DenseStringElementsAttrStorage *
constructDenseStringElementsAttrStorage278   construct(AttributeStorageAllocator &allocator, KeyTy key) {
279     // If the data buffer is non-empty, we copy it into the allocator with a
280     // 64-bit alignment.
281     ArrayRef<StringRef> copy, data = key.data;
282     if (data.empty()) {
283       return new (allocator.allocate<DenseStringElementsAttrStorage>())
284           DenseStringElementsAttrStorage(key.type, copy, key.isSplat);
285     }
286 
287     int numEntries = key.isSplat ? 1 : data.size();
288 
289     // Compute the amount data needed to store the ArrayRef and StringRef
290     // contents.
291     size_t dataSize = sizeof(StringRef) * numEntries;
292     for (int i = 0; i < numEntries; i++)
293       dataSize += data[i].size();
294 
295     char *rawData = reinterpret_cast<char *>(
296         allocator.allocate(dataSize, alignof(uint64_t)));
297 
298     // Setup a mutable array ref of our string refs so that we can update their
299     // contents.
300     auto mutableCopy = MutableArrayRef<StringRef>(
301         reinterpret_cast<StringRef *>(rawData), numEntries);
302     auto *stringData = rawData + numEntries * sizeof(StringRef);
303 
304     for (int i = 0; i < numEntries; i++) {
305       memcpy(stringData, data[i].data(), data[i].size());
306       mutableCopy[i] = StringRef(stringData, data[i].size());
307       stringData += data[i].size();
308     }
309 
310     copy =
311         ArrayRef<StringRef>(reinterpret_cast<StringRef *>(rawData), numEntries);
312 
313     return new (allocator.allocate<DenseStringElementsAttrStorage>())
314         DenseStringElementsAttrStorage(key.type, copy, key.isSplat);
315   }
316 
317   ArrayRef<StringRef> data;
318 };
319 
320 //===----------------------------------------------------------------------===//
321 // StringAttr
322 //===----------------------------------------------------------------------===//
323 
324 struct StringAttrStorage : public AttributeStorage {
StringAttrStorageStringAttrStorage325   StringAttrStorage(StringRef value, Type type)
326       : type(type), value(value), referencedDialect(nullptr) {}
327 
328   /// The hash key is a tuple of the parameter types.
329   using KeyTy = std::pair<StringRef, Type>;
330   bool operator==(const KeyTy &key) const {
331     return value == key.first && type == key.second;
332   }
hashKeyStringAttrStorage333   static ::llvm::hash_code hashKey(const KeyTy &key) {
334     return DenseMapInfo<KeyTy>::getHashValue(key);
335   }
336 
337   /// Define a construction method for creating a new instance of this
338   /// storage.
constructStringAttrStorage339   static StringAttrStorage *construct(AttributeStorageAllocator &allocator,
340                                       const KeyTy &key) {
341     return new (allocator.allocate<StringAttrStorage>())
342         StringAttrStorage(allocator.copyInto(key.first), key.second);
343   }
344 
345   /// Initialize the storage given an MLIRContext.
346   void initialize(MLIRContext *context);
347 
348   /// The type of the string.
349   Type type;
350   /// The raw string value.
351   StringRef value;
352   /// If the string value contains a dialect namespace prefix (e.g.
353   /// dialect.blah), this is the dialect referenced.
354   Dialect *referencedDialect;
355 };
356 
357 //===----------------------------------------------------------------------===//
358 // DistinctAttr
359 //===----------------------------------------------------------------------===//
360 
361 /// An attribute to store a distinct reference to another attribute.
362 struct DistinctAttrStorage : public AttributeStorage {
363   using KeyTy = Attribute;
364 
DistinctAttrStorageDistinctAttrStorage365   DistinctAttrStorage(Attribute referencedAttr)
366       : referencedAttr(referencedAttr) {}
367 
368   /// Returns the referenced attribute as key.
getAsKeyDistinctAttrStorage369   KeyTy getAsKey() const { return KeyTy(referencedAttr); }
370 
371   /// The referenced attribute.
372   Attribute referencedAttr;
373 };
374 
375 /// A specialized attribute uniquer for distinct attributes that always
376 /// allocates since the distinct attribute instances use the address of their
377 /// storage as unique identifier.
378 class DistinctAttributeUniquer {
379 public:
380   /// Creates a distinct attribute storage. Allocates every time since the
381   /// address of the storage serves as unique identifier.
382   template <typename T, typename... Args>
get(MLIRContext * context,Args &&...args)383   static T get(MLIRContext *context, Args &&...args) {
384     static_assert(std::is_same_v<typename T::ImplType, DistinctAttrStorage>,
385                   "expects a distinct attribute storage");
386     DistinctAttrStorage *storage = DistinctAttributeUniquer::allocateStorage(
387         context, std::forward<Args>(args)...);
388     storage->initializeAbstractAttribute(
389         AbstractAttribute::lookup(DistinctAttr::getTypeID(), context));
390     return storage;
391   }
392 
393 private:
394   /// Allocates a distinct attribute storage.
395   static DistinctAttrStorage *allocateStorage(MLIRContext *context,
396                                               Attribute referencedAttr);
397 };
398 
399 /// An allocator for distinct attribute storage instances. It uses thread local
400 /// bump pointer allocators stored in a thread local cache to ensure the storage
401 /// is freed after the destruction of the distinct attribute allocator.
402 class DistinctAttributeAllocator {
403 public:
404   DistinctAttributeAllocator() = default;
405 
406   DistinctAttributeAllocator(DistinctAttributeAllocator &&) = delete;
407   DistinctAttributeAllocator(const DistinctAttributeAllocator &) = delete;
408   DistinctAttributeAllocator &
409   operator=(const DistinctAttributeAllocator &) = delete;
410 
411   /// Allocates a distinct attribute storage using a thread local bump pointer
412   /// allocator to enable synchronization free parallel allocations.
allocate(Attribute referencedAttr)413   DistinctAttrStorage *allocate(Attribute referencedAttr) {
414     return new (allocatorCache.get().Allocate<DistinctAttrStorage>())
415         DistinctAttrStorage(referencedAttr);
416   }
417 
418 private:
419   ThreadLocalCache<llvm::BumpPtrAllocator> allocatorCache;
420 };
421 } // namespace detail
422 } // namespace mlir
423 
424 #endif // ATTRIBUTEDETAIL_H_
425