xref: /llvm-project/mlir/include/mlir/Support/StorageUniquer.h (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- StorageUniquer.h - Common Storage Class Uniquer ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_SUPPORT_STORAGEUNIQUER_H
10 #define MLIR_SUPPORT_STORAGEUNIQUER_H
11 
12 #include "mlir/Support/LLVM.h"
13 #include "mlir/Support/TypeID.h"
14 #include "llvm/ADT/ArrayRef.h"
15 #include "llvm/ADT/DenseSet.h"
16 #include "llvm/ADT/StringRef.h"
17 #include "llvm/Support/Allocator.h"
18 #include <utility>
19 
20 namespace mlir {
namespace detail {
struct StorageUniquerImpl;

/// Detection alias for a static 'getKey' on a storage class. It names the
/// result type of calling 'getKey' on the first template argument with values
/// of the remaining argument types, and is ill-formed when no such call
/// exists, which makes it suitable for use with a detection idiom.
template <class StorageT, class... ArgTs>
using has_impltype_getkey_t =
    decltype(StorageT::getKey(std::declval<ArgTs>()...));

/// Detection alias for a static 'hashKey' on a storage class. It names the
/// result type of calling 'hashKey' on the first template argument with a
/// value of the second argument type, and is ill-formed when no such call
/// exists.
template <class StorageT, class KeyT>
using has_impltype_hash_t = decltype(StorageT::hashKey(std::declval<KeyT>()));
} // namespace detail
32 
33 /// A utility class to get or create instances of "storage classes". These
34 /// storage classes must derive from 'StorageUniquer::BaseStorage'.
35 ///
36 /// For non-parametric storage classes, i.e. singleton classes, nothing else is
37 /// needed. Instances of these classes can be created by calling `get` without
38 /// trailing arguments.
39 ///
40 /// Otherwise, the parametric storage classes may be created with `get`,
41 /// and must respect the following:
42 ///    - Define a type alias, KeyTy, to a type that uniquely identifies the
43 ///      instance of the storage class.
44 ///      * The key type must be constructible from the values passed into the
///        `get` call.
46 ///      * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
47 ///        storage class must define a hashing method:
48 ///         'static unsigned hashKey(const KeyTy &)'
49 ///
50 ///    - Provide a method, 'bool operator==(const KeyTy &) const', to
51 ///      compare the storage instance against an instance of the key type.
52 ///
53 ///    - Provide a static construction method:
54 ///        'DerivedStorage *construct(StorageAllocator &, const KeyTy &key)'
55 ///      that builds a unique instance of the derived storage. The arguments to
56 ///      this function are an allocator to store any uniqued data and the key
57 ///      type for this storage.
58 ///
59 ///    - Provide a cleanup method:
60 ///        'void cleanup()'
61 ///      that is called when erasing a storage instance. This should cleanup any
62 ///      fields of the storage as necessary and not attempt to free the memory
63 ///      of the storage itself.
64 ///
65 /// Storage classes may have an optional mutable component, which must not take
66 /// part in the unique immutable key. In this case, storage classes may be
67 /// mutated with `mutate` and must additionally respect the following:
68 ///    - Provide a mutation method:
69 ///        'LogicalResult mutate(StorageAllocator &, <...>)'
70 ///      that is called when mutating a storage instance. The first argument is
71 ///      an allocator to store any mutable data, and the remaining arguments are
72 ///      forwarded from the call site. The storage can be mutated at any time
73 ///      after creation. Care must be taken to avoid excessive mutation since
74 ///      the allocated storage can keep containing previous states. The return
75 ///      value of the function is used to indicate whether the mutation was
76 ///      successful, e.g., to limit the number of mutations or enable deferred
77 ///      one-time assignment of the mutable component.
78 ///
79 /// All storage classes must be registered with the uniquer via
80 /// `registerParametricStorageType` or `registerSingletonStorageType`
81 /// using an appropriate unique `TypeID` for the storage class.
82 class StorageUniquer {
83 public:
84   /// This class acts as the base storage that all storage classes must derived
85   /// from.
86   class alignas(8) BaseStorage {
87   protected:
88     BaseStorage() = default;
89   };
90 
91   /// This is a utility allocator used to allocate memory for instances of
92   /// derived types.
93   class StorageAllocator {
94   public:
95     /// Copy the specified array of elements into memory managed by our bump
96     /// pointer allocator.  This assumes the elements are all PODs.
97     template <typename T>
copyInto(ArrayRef<T> elements)98     ArrayRef<T> copyInto(ArrayRef<T> elements) {
99       if (elements.empty())
100         return std::nullopt;
101       auto result = allocator.Allocate<T>(elements.size());
102       std::uninitialized_copy(elements.begin(), elements.end(), result);
103       return ArrayRef<T>(result, elements.size());
104     }
105 
106     /// Copy the provided string into memory managed by our bump pointer
107     /// allocator.
copyInto(StringRef str)108     StringRef copyInto(StringRef str) {
109       if (str.empty())
110         return StringRef();
111 
112       char *result = allocator.Allocate<char>(str.size() + 1);
113       std::uninitialized_copy(str.begin(), str.end(), result);
114       result[str.size()] = 0;
115       return StringRef(result, str.size());
116     }
117 
118     /// Allocate an instance of the provided type.
119     template <typename T>
allocate()120     T *allocate() {
121       return allocator.Allocate<T>();
122     }
123 
124     /// Allocate 'size' bytes of 'alignment' aligned memory.
allocate(size_t size,size_t alignment)125     void *allocate(size_t size, size_t alignment) {
126       return allocator.Allocate(size, alignment);
127     }
128 
129     /// Returns true if this allocator allocated the provided object pointer.
allocated(const void * ptr)130     bool allocated(const void *ptr) {
131       return allocator.identifyObject(ptr).has_value();
132     }
133 
134   private:
135     /// The raw allocator for type storage objects.
136     llvm::BumpPtrAllocator allocator;
137   };
138 
139   StorageUniquer();
140   ~StorageUniquer();
141 
142   /// Set the flag specifying if multi-threading is disabled within the uniquer.
143   void disableMultithreading(bool disable = true);
144 
145   /// Register a new parametric storage class, this is necessary to create
146   /// instances of this class type. `id` is the type identifier that will be
147   /// used to identify this type when creating instances of it via 'get'.
148   template <typename Storage>
registerParametricStorageType(TypeID id)149   void registerParametricStorageType(TypeID id) {
150     // If the storage is trivially destructible, we don't need a destructor
151     // function.
152     if constexpr (std::is_trivially_destructible_v<Storage>)
153       return registerParametricStorageTypeImpl(id, nullptr);
154     registerParametricStorageTypeImpl(id, [](BaseStorage *storage) {
155       static_cast<Storage *>(storage)->~Storage();
156     });
157   }
158   /// Utility override when the storage type represents the type id.
159   template <typename Storage>
registerParametricStorageType()160   void registerParametricStorageType() {
161     registerParametricStorageType<Storage>(TypeID::get<Storage>());
162   }
163   /// Register a new singleton storage class, this is necessary to get the
164   /// singletone instance. `id` is the type identifier that will be used to
165   /// access the singleton instance via 'get'. An optional initialization
166   /// function may also be provided to initialize the newly created storage
167   /// instance, and used when the singleton instance is created.
168   template <typename Storage>
registerSingletonStorageType(TypeID id,function_ref<void (Storage *)> initFn)169   void registerSingletonStorageType(TypeID id,
170                                     function_ref<void(Storage *)> initFn) {
171     auto ctorFn = [&](StorageAllocator &allocator) {
172       auto *storage = new (allocator.allocate<Storage>()) Storage();
173       if (initFn)
174         initFn(storage);
175       return storage;
176     };
177     registerSingletonImpl(id, ctorFn);
178   }
179   template <typename Storage>
registerSingletonStorageType(TypeID id)180   void registerSingletonStorageType(TypeID id) {
181     registerSingletonStorageType<Storage>(id, std::nullopt);
182   }
183   /// Utility override when the storage type represents the type id.
184   template <typename Storage>
185   void registerSingletonStorageType(function_ref<void(Storage *)> initFn = {}) {
186     registerSingletonStorageType<Storage>(TypeID::get<Storage>(), initFn);
187   }
188 
189   /// Gets a uniqued instance of 'Storage'. 'id' is the type id used when
190   /// registering the storage instance. 'initFn' is an optional parameter that
191   /// can be used to initialize a newly inserted storage instance. This function
192   /// is used for derived types that have complex storage or uniquing
193   /// constraints.
194   template <typename Storage, typename... Args>
get(function_ref<void (Storage *)> initFn,TypeID id,Args &&...args)195   Storage *get(function_ref<void(Storage *)> initFn, TypeID id,
196                Args &&...args) {
197     // Construct a value of the derived key type.
198     auto derivedKey = getKey<Storage>(std::forward<Args>(args)...);
199 
200     // Create a hash of the derived key.
201     unsigned hashValue = getHash<Storage>(derivedKey);
202 
203     // Generate an equality function for the derived storage.
204     auto isEqual = [&derivedKey](const BaseStorage *existing) {
205       return static_cast<const Storage &>(*existing) == derivedKey;
206     };
207 
208     // Generate a constructor function for the derived storage.
209     auto ctorFn = [&](StorageAllocator &allocator) {
210       auto *storage = Storage::construct(allocator, std::move(derivedKey));
211       if (initFn)
212         initFn(storage);
213       return storage;
214     };
215 
216     // Get an instance for the derived storage.
217     return static_cast<Storage *>(
218         getParametricStorageTypeImpl(id, hashValue, isEqual, ctorFn));
219   }
220   /// Utility override when the storage type represents the type id.
221   template <typename Storage, typename... Args>
get(function_ref<void (Storage *)> initFn,Args &&...args)222   Storage *get(function_ref<void(Storage *)> initFn, Args &&...args) {
223     return get<Storage>(initFn, TypeID::get<Storage>(),
224                         std::forward<Args>(args)...);
225   }
226 
227   /// Gets a uniqued instance of 'Storage' which is a singleton storage type.
228   /// 'id' is the type id used when registering the storage instance.
229   template <typename Storage>
get(TypeID id)230   Storage *get(TypeID id) {
231     return static_cast<Storage *>(getSingletonImpl(id));
232   }
233   /// Utility override when the storage type represents the type id.
234   template <typename Storage>
get()235   Storage *get() {
236     return get<Storage>(TypeID::get<Storage>());
237   }
238 
239   /// Test if there is a singleton storage uniquer initialized for the provided
240   /// TypeID. This is only useful for debugging/diagnostic purpose: the uniquer
241   /// is initialized when a dialect is loaded.
242   bool isSingletonStorageInitialized(TypeID id);
243 
244   /// Test if there is a parametric storage uniquer initialized for the provided
245   /// TypeID. This is only useful for debugging/diagnostic purpose: the uniquer
246   /// is initialized when a dialect is loaded.
247   bool isParametricStorageInitialized(TypeID id);
248 
249   /// Changes the mutable component of 'storage' by forwarding the trailing
250   /// arguments to the 'mutate' function of the derived class.
251   template <typename Storage, typename... Args>
mutate(TypeID id,Storage * storage,Args &&...args)252   LogicalResult mutate(TypeID id, Storage *storage, Args &&...args) {
253     auto mutationFn = [&](StorageAllocator &allocator) -> LogicalResult {
254       return static_cast<Storage &>(*storage).mutate(
255           allocator, std::forward<Args>(args)...);
256     };
257     return mutateImpl(id, storage, mutationFn);
258   }
259 
260 private:
261   /// Implementation for getting/creating an instance of a derived type with
262   /// parametric storage.
263   BaseStorage *getParametricStorageTypeImpl(
264       TypeID id, unsigned hashValue,
265       function_ref<bool(const BaseStorage *)> isEqual,
266       function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
267 
268   /// Implementation for registering an instance of a derived type with
269   /// parametric storage. This method takes an optional destructor function that
270   /// destructs storage instances when necessary.
271   void registerParametricStorageTypeImpl(
272       TypeID id, function_ref<void(BaseStorage *)> destructorFn);
273 
274   /// Implementation for getting an instance of a derived type with default
275   /// storage.
276   BaseStorage *getSingletonImpl(TypeID id);
277 
278   /// Implementation for registering an instance of a derived type with default
279   /// storage.
280   void
281   registerSingletonImpl(TypeID id,
282                         function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
283 
284   /// Implementation for mutating an instance of a derived storage.
285   LogicalResult
286   mutateImpl(TypeID id, BaseStorage *storage,
287              function_ref<LogicalResult(StorageAllocator &)> mutationFn);
288 
289   /// The internal implementation class.
290   std::unique_ptr<detail::StorageUniquerImpl> impl;
291 
292   //===--------------------------------------------------------------------===//
293   // Key Construction
294   //===--------------------------------------------------------------------===//
295 
296   /// Used to construct an instance of 'ImplTy::KeyTy' if there is an
297   /// 'ImplTy::getKey' function for the provided arguments.  Otherwise, then we
298   /// try to directly construct the 'ImplTy::KeyTy' with the provided arguments.
299   template <typename ImplTy, typename... Args>
getKey(Args &&...args)300   static typename ImplTy::KeyTy getKey(Args &&...args) {
301     if constexpr (llvm::is_detected<detail::has_impltype_getkey_t, ImplTy,
302                                     Args...>::value)
303       return ImplTy::getKey(std::forward<Args>(args)...);
304     else
305       return typename ImplTy::KeyTy(std::forward<Args>(args)...);
306   }
307 
308   //===--------------------------------------------------------------------===//
309   // Key Hashing
310   //===--------------------------------------------------------------------===//
311 
312   /// Used to generate a hash for the `ImplTy` of a storage instance if
313   /// there is a `ImplTy::hashKey.  Otherwise, if there is no `ImplTy::hashKey`
314   /// then default to using the 'llvm::DenseMapInfo' definition for
315   /// 'DerivedKey' for generating a hash.
316   template <typename ImplTy, typename DerivedKey>
getHash(const DerivedKey & derivedKey)317   static ::llvm::hash_code getHash(const DerivedKey &derivedKey) {
318     if constexpr (llvm::is_detected<detail::has_impltype_hash_t, ImplTy,
319                                     DerivedKey>::value)
320       return ImplTy::hashKey(derivedKey);
321     else
322       return DenseMapInfo<DerivedKey>::getHashValue(derivedKey);
323   }
324 };
325 } // namespace mlir
326 
327 #endif
328