xref: /llvm-project/mlir/lib/Support/StorageUniquer.cpp (revision ceff9524f92445b3110fb79f0d6239713b5df153)
1 //===- StorageUniquer.cpp - Common Storage Class Uniquer ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Support/StorageUniquer.h"
10 
11 #include "mlir/Support/LLVM.h"
12 #include "mlir/Support/ThreadLocalCache.h"
13 #include "mlir/Support/TypeID.h"
14 #include "llvm/Support/RWMutex.h"
15 
16 using namespace mlir;
17 using namespace mlir::detail;
18 
19 namespace {
20 /// This class represents a uniquer for storage instances of a specific type
21 /// that has parametric storage. It contains all of the necessary data to unique
22 /// storage instances in a thread safe way. This allows for the main uniquer to
23 /// bucket each of the individual sub-types removing the need to lock the main
24 /// uniquer itself.
25 class ParametricStorageUniquer {
26 public:
27   using BaseStorage = StorageUniquer::BaseStorage;
28   using StorageAllocator = StorageUniquer::StorageAllocator;
29 
30   /// A lookup key for derived instances of storage objects.
31   struct LookupKey {
32     /// The known hash value of the key.
33     unsigned hashValue;
34 
35     /// An equality function for comparing with an existing storage instance.
36     function_ref<bool(const BaseStorage *)> isEqual;
37   };
38 
39 private:
40   /// A utility wrapper object representing a hashed storage object. This class
41   /// contains a storage object and an existing computed hash value.
42   struct HashedStorage {
HashedStorage__anon242314e60111::ParametricStorageUniquer::HashedStorage43     HashedStorage(unsigned hashValue = 0, BaseStorage *storage = nullptr)
44         : hashValue(hashValue), storage(storage) {}
45     unsigned hashValue;
46     BaseStorage *storage;
47   };
48 
49   /// Storage info for derived TypeStorage objects.
50   struct StorageKeyInfo {
getEmptyKey__anon242314e60111::ParametricStorageUniquer::StorageKeyInfo51     static inline HashedStorage getEmptyKey() {
52       return HashedStorage(0, DenseMapInfo<BaseStorage *>::getEmptyKey());
53     }
getTombstoneKey__anon242314e60111::ParametricStorageUniquer::StorageKeyInfo54     static inline HashedStorage getTombstoneKey() {
55       return HashedStorage(0, DenseMapInfo<BaseStorage *>::getTombstoneKey());
56     }
57 
getHashValue__anon242314e60111::ParametricStorageUniquer::StorageKeyInfo58     static inline unsigned getHashValue(const HashedStorage &key) {
59       return key.hashValue;
60     }
getHashValue__anon242314e60111::ParametricStorageUniquer::StorageKeyInfo61     static inline unsigned getHashValue(const LookupKey &key) {
62       return key.hashValue;
63     }
64 
isEqual__anon242314e60111::ParametricStorageUniquer::StorageKeyInfo65     static inline bool isEqual(const HashedStorage &lhs,
66                                const HashedStorage &rhs) {
67       return lhs.storage == rhs.storage;
68     }
isEqual__anon242314e60111::ParametricStorageUniquer::StorageKeyInfo69     static inline bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) {
70       if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
71         return false;
72       // Invoke the equality function on the lookup key.
73       return lhs.isEqual(rhs.storage);
74     }
75   };
76   using StorageTypeSet = DenseSet<HashedStorage, StorageKeyInfo>;
77 
78   /// This class represents a single shard of the uniquer. The uniquer uses a
79   /// set of shards to allow for multiple threads to create instances with less
80   /// lock contention.
81   struct Shard {
82     /// The set containing the allocated storage instances.
83     StorageTypeSet instances;
84 
85 #if LLVM_ENABLE_THREADS != 0
86     /// A mutex to keep uniquing thread-safe.
87     llvm::sys::SmartRWMutex<true> mutex;
88 #endif
89   };
90 
91   /// Get or create an instance of a param derived type in an thread-unsafe
92   /// fashion.
getOrCreateUnsafe(Shard & shard,LookupKey & key,function_ref<BaseStorage * ()> ctorFn)93   BaseStorage *getOrCreateUnsafe(Shard &shard, LookupKey &key,
94                                  function_ref<BaseStorage *()> ctorFn) {
95     auto existing = shard.instances.insert_as({key.hashValue}, key);
96     BaseStorage *&storage = existing.first->storage;
97     if (existing.second)
98       storage = ctorFn();
99     return storage;
100   }
101 
102   /// Destroy all of the storage instances within the given shard.
destroyShardInstances(Shard & shard)103   void destroyShardInstances(Shard &shard) {
104     if (!destructorFn)
105       return;
106     for (HashedStorage &instance : shard.instances)
107       destructorFn(instance.storage);
108   }
109 
110 public:
111 #if LLVM_ENABLE_THREADS != 0
112   /// Initialize the storage uniquer with a given number of storage shards to
113   /// use. The provided shard number is required to be a valid power of 2. The
114   /// destructor function is used to destroy any allocated storage instances.
ParametricStorageUniquer(function_ref<void (BaseStorage *)> destructorFn,size_t numShards=8)115   ParametricStorageUniquer(function_ref<void(BaseStorage *)> destructorFn,
116                            size_t numShards = 8)
117       : shards(new std::atomic<Shard *>[numShards]), numShards(numShards),
118         destructorFn(destructorFn) {
119     assert(llvm::isPowerOf2_64(numShards) &&
120            "the number of shards is required to be a power of 2");
121     for (size_t i = 0; i < numShards; i++)
122       shards[i].store(nullptr, std::memory_order_relaxed);
123   }
~ParametricStorageUniquer()124   ~ParametricStorageUniquer() {
125     // Free all of the allocated shards.
126     for (size_t i = 0; i != numShards; ++i) {
127       if (Shard *shard = shards[i].load()) {
128         destroyShardInstances(*shard);
129         delete shard;
130       }
131     }
132   }
133   /// Get or create an instance of a parametric type.
getOrCreate(bool threadingIsEnabled,unsigned hashValue,function_ref<bool (const BaseStorage *)> isEqual,function_ref<BaseStorage * ()> ctorFn)134   BaseStorage *getOrCreate(bool threadingIsEnabled, unsigned hashValue,
135                            function_ref<bool(const BaseStorage *)> isEqual,
136                            function_ref<BaseStorage *()> ctorFn) {
137     Shard &shard = getShard(hashValue);
138     ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual};
139     if (!threadingIsEnabled)
140       return getOrCreateUnsafe(shard, lookupKey, ctorFn);
141 
142     // Check for a instance of this object in the local cache.
143     auto localIt = localCache->insert_as({hashValue}, lookupKey);
144     BaseStorage *&localInst = localIt.first->storage;
145     if (localInst)
146       return localInst;
147 
148     // Check for an existing instance in read-only mode.
149     {
150       llvm::sys::SmartScopedReader<true> typeLock(shard.mutex);
151       auto it = shard.instances.find_as(lookupKey);
152       if (it != shard.instances.end())
153         return localInst = it->storage;
154     }
155 
156     // Acquire a writer-lock so that we can safely create the new storage
157     // instance.
158     llvm::sys::SmartScopedWriter<true> typeLock(shard.mutex);
159     return localInst = getOrCreateUnsafe(shard, lookupKey, ctorFn);
160   }
161 
162   /// Run a mutation function on the provided storage object in a thread-safe
163   /// way.
mutate(bool threadingIsEnabled,BaseStorage * storage,function_ref<LogicalResult ()> mutationFn)164   LogicalResult mutate(bool threadingIsEnabled, BaseStorage *storage,
165                        function_ref<LogicalResult()> mutationFn) {
166     if (!threadingIsEnabled)
167       return mutationFn();
168 
169     // Get a shard to use for mutating this storage instance. It doesn't need to
170     // be the same shard as the original allocation, but does need to be
171     // deterministic.
172     Shard &shard = getShard(llvm::hash_value(storage));
173     llvm::sys::SmartScopedWriter<true> lock(shard.mutex);
174     return mutationFn();
175   }
176 
177 private:
178   /// Return the shard used for the given hash value.
getShard(unsigned hashValue)179   Shard &getShard(unsigned hashValue) {
180     // Get a shard number from the provided hashvalue.
181     unsigned shardNum = hashValue & (numShards - 1);
182 
183     // Try to acquire an already initialized shard.
184     Shard *shard = shards[shardNum].load(std::memory_order_acquire);
185     if (shard)
186       return *shard;
187 
188     // Otherwise, try to allocate a new shard.
189     Shard *newShard = new Shard();
190     if (shards[shardNum].compare_exchange_strong(shard, newShard))
191       return *newShard;
192 
193     // If one was allocated before we can initialize ours, delete ours.
194     delete newShard;
195     return *shard;
196   }
197 
198   /// A thread local cache for storage objects. This helps to reduce the lock
199   /// contention when an object already existing in the cache.
200   ThreadLocalCache<StorageTypeSet> localCache;
201 
202   /// A set of uniquer shards to allow for further bucketing accesses for
203   /// instances of this storage type. Each shard is lazily initialized to reduce
204   /// the overhead when only a small amount of shards are in use.
205   std::unique_ptr<std::atomic<Shard *>[]> shards;
206 
207   /// The number of available shards.
208   size_t numShards;
209 
210   /// Function to used to destruct any allocated storage instances.
211   function_ref<void(BaseStorage *)> destructorFn;
212 
213 #else
214   /// If multi-threading is disabled, ignore the shard parameter as we will
215   /// always use one shard. The destructor function is used to destroy any
216   /// allocated storage instances.
217   ParametricStorageUniquer(function_ref<void(BaseStorage *)> destructorFn,
218                            size_t numShards = 0)
219       : destructorFn(destructorFn) {}
220   ~ParametricStorageUniquer() { destroyShardInstances(shard); }
221 
222   /// Get or create an instance of a parametric type.
223   BaseStorage *
224   getOrCreate(bool threadingIsEnabled, unsigned hashValue,
225               function_ref<bool(const BaseStorage *)> isEqual,
226               function_ref<BaseStorage *()> ctorFn) {
227     ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual};
228     return getOrCreateUnsafe(shard, lookupKey, ctorFn);
229   }
230   /// Run a mutation function on the provided storage object in a thread-safe
231   /// way.
232   LogicalResult
233   mutate(bool threadingIsEnabled, BaseStorage *storage,
234          function_ref<LogicalResult()> mutationFn) {
235     return mutationFn();
236   }
237 
238 private:
239   /// The main uniquer shard that is used for allocating storage instances.
240   Shard shard;
241 
242   /// Function to used to destruct any allocated storage instances.
243   function_ref<void(BaseStorage *)> destructorFn;
244 #endif
245 };
246 } // namespace
247 
248 namespace mlir {
249 namespace detail {
250 /// This is the implementation of the StorageUniquer class.
251 struct StorageUniquerImpl {
252   using BaseStorage = StorageUniquer::BaseStorage;
253   using StorageAllocator = StorageUniquer::StorageAllocator;
254 
255   //===--------------------------------------------------------------------===//
256   // Parametric Storage
257   //===--------------------------------------------------------------------===//
258 
259   /// Check if an instance of a parametric storage class exists.
hasParametricStoragemlir::detail::StorageUniquerImpl260   bool hasParametricStorage(TypeID id) { return parametricUniquers.count(id); }
261 
262   /// Get or create an instance of a parametric type.
263   BaseStorage *
getOrCreatemlir::detail::StorageUniquerImpl264   getOrCreate(TypeID id, unsigned hashValue,
265               function_ref<bool(const BaseStorage *)> isEqual,
266               function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
267     assert(parametricUniquers.count(id) &&
268            "creating unregistered storage instance");
269     ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
270     return storageUniquer.getOrCreate(
271         threadingIsEnabled, hashValue, isEqual,
272         [&] { return ctorFn(getThreadSafeAllocator()); });
273   }
274 
275   /// Run a mutation function on the provided storage object in a thread-safe
276   /// way.
277   LogicalResult
mutatemlir::detail::StorageUniquerImpl278   mutate(TypeID id, BaseStorage *storage,
279          function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
280     assert(parametricUniquers.count(id) &&
281            "mutating unregistered storage instance");
282     ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
283     return storageUniquer.mutate(threadingIsEnabled, storage, [&] {
284       return mutationFn(getThreadSafeAllocator());
285     });
286   }
287 
288   /// Return an allocator that can be used to safely allocate instances on the
289   /// current thread.
getThreadSafeAllocatormlir::detail::StorageUniquerImpl290   StorageAllocator &getThreadSafeAllocator() {
291 #if LLVM_ENABLE_THREADS != 0
292     if (!threadingIsEnabled)
293       return allocator;
294 
295     // If the allocator has not been initialized, create a new one.
296     StorageAllocator *&threadAllocator = threadSafeAllocator.get();
297     if (!threadAllocator) {
298       threadAllocator = new StorageAllocator();
299 
300       // Record this allocator, given that we don't want it to be destroyed when
301       // the thread dies.
302       llvm::sys::SmartScopedLock<true> lock(threadAllocatorMutex);
303       threadAllocators.push_back(
304           std::unique_ptr<StorageAllocator>(threadAllocator));
305     }
306 
307     return *threadAllocator;
308 #else
309     return allocator;
310 #endif
311   }
312 
313   //===--------------------------------------------------------------------===//
314   // Singleton Storage
315   //===--------------------------------------------------------------------===//
316 
317   /// Get or create an instance of a singleton storage class.
getSingletonmlir::detail::StorageUniquerImpl318   BaseStorage *getSingleton(TypeID id) {
319     BaseStorage *singletonInstance = singletonInstances[id];
320     assert(singletonInstance && "expected singleton instance to exist");
321     return singletonInstance;
322   }
323 
324   /// Check if an instance of a singleton storage class exists.
hasSingletonmlir::detail::StorageUniquerImpl325   bool hasSingleton(TypeID id) const { return singletonInstances.count(id); }
326 
327   //===--------------------------------------------------------------------===//
328   // Instance Storage
329   //===--------------------------------------------------------------------===//
330 
331 #if LLVM_ENABLE_THREADS != 0
332   /// A thread local set of allocators used for uniquing parametric instances,
333   /// or other data allocated in thread volatile situations.
334   ThreadLocalCache<StorageAllocator *> threadSafeAllocator;
335 
336   /// All of the allocators that have been created for thread based allocation.
337   std::vector<std::unique_ptr<StorageAllocator>> threadAllocators;
338 
339   /// A mutex used for safely adding a new thread allocator.
340   llvm::sys::SmartMutex<true> threadAllocatorMutex;
341 #endif
342 
343   /// Main allocator used for uniquing singleton instances, and other state when
344   /// thread safety is guaranteed.
345   StorageAllocator allocator;
346 
347   /// Map of type ids to the storage uniquer to use for registered objects.
348   DenseMap<TypeID, std::unique_ptr<ParametricStorageUniquer>>
349       parametricUniquers;
350 
351   /// Map of type ids to a singleton instance when the storage class is a
352   /// singleton.
353   DenseMap<TypeID, BaseStorage *> singletonInstances;
354 
355   /// Flag specifying if multi-threading is enabled within the uniquer.
356   bool threadingIsEnabled = true;
357 };
358 } // namespace detail
359 } // namespace mlir
360 
StorageUniquer()361 StorageUniquer::StorageUniquer() : impl(new StorageUniquerImpl()) {}
362 StorageUniquer::~StorageUniquer() = default;
363 
364 /// Set the flag specifying if multi-threading is disabled within the uniquer.
disableMultithreading(bool disable)365 void StorageUniquer::disableMultithreading(bool disable) {
366   impl->threadingIsEnabled = !disable;
367 }
368 
369 /// Implementation for getting/creating an instance of a derived type with
370 /// parametric storage.
getParametricStorageTypeImpl(TypeID id,unsigned hashValue,function_ref<bool (const BaseStorage *)> isEqual,function_ref<BaseStorage * (StorageAllocator &)> ctorFn)371 auto StorageUniquer::getParametricStorageTypeImpl(
372     TypeID id, unsigned hashValue,
373     function_ref<bool(const BaseStorage *)> isEqual,
374     function_ref<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * {
375   return impl->getOrCreate(id, hashValue, isEqual, ctorFn);
376 }
377 
378 /// Implementation for registering an instance of a derived type with
379 /// parametric storage.
registerParametricStorageTypeImpl(TypeID id,function_ref<void (BaseStorage *)> destructorFn)380 void StorageUniquer::registerParametricStorageTypeImpl(
381     TypeID id, function_ref<void(BaseStorage *)> destructorFn) {
382   impl->parametricUniquers.try_emplace(
383       id, std::make_unique<ParametricStorageUniquer>(destructorFn));
384 }
385 
386 /// Implementation for getting an instance of a derived type with default
387 /// storage.
getSingletonImpl(TypeID id)388 auto StorageUniquer::getSingletonImpl(TypeID id) -> BaseStorage * {
389   return impl->getSingleton(id);
390 }
391 
392 /// Test is the storage singleton is initialized.
isSingletonStorageInitialized(TypeID id)393 bool StorageUniquer::isSingletonStorageInitialized(TypeID id) {
394   return impl->hasSingleton(id);
395 }
396 
397 /// Test is the parametric storage is initialized.
isParametricStorageInitialized(TypeID id)398 bool StorageUniquer::isParametricStorageInitialized(TypeID id) {
399   return impl->hasParametricStorage(id);
400 }
401 
402 /// Implementation for registering an instance of a derived type with default
403 /// storage.
registerSingletonImpl(TypeID id,function_ref<BaseStorage * (StorageAllocator &)> ctorFn)404 void StorageUniquer::registerSingletonImpl(
405     TypeID id, function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
406   assert(!impl->singletonInstances.count(id) &&
407          "storage class already registered");
408   impl->singletonInstances.try_emplace(id, ctorFn(impl->allocator));
409 }
410 
411 /// Implementation for mutating an instance of a derived storage.
mutateImpl(TypeID id,BaseStorage * storage,function_ref<LogicalResult (StorageAllocator &)> mutationFn)412 LogicalResult StorageUniquer::mutateImpl(
413     TypeID id, BaseStorage *storage,
414     function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
415   return impl->mutate(id, storage, mutationFn);
416 }
417