1 //===- StorageUniquer.h - Common Storage Class Uniquer ----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_SUPPORT_STORAGEUNIQUER_H 10 #define MLIR_SUPPORT_STORAGEUNIQUER_H 11 12 #include "mlir/Support/LLVM.h" 13 #include "mlir/Support/TypeID.h" 14 #include "llvm/ADT/ArrayRef.h" 15 #include "llvm/ADT/DenseSet.h" 16 #include "llvm/ADT/StringRef.h" 17 #include "llvm/Support/Allocator.h" 18 #include <utility> 19 20 namespace mlir { 21 namespace detail { 22 struct StorageUniquerImpl; 23 24 /// Trait to check if ImplTy provides a 'getKey' method with types 'Args'. 25 template <typename ImplTy, typename... Args> 26 using has_impltype_getkey_t = decltype(ImplTy::getKey(std::declval<Args>()...)); 27 28 /// Trait to check if ImplTy provides a 'hashKey' method for 'T'. 29 template <typename ImplTy, typename T> 30 using has_impltype_hash_t = decltype(ImplTy::hashKey(std::declval<T>())); 31 } // namespace detail 32 33 /// A utility class to get or create instances of "storage classes". These 34 /// storage classes must derive from 'StorageUniquer::BaseStorage'. 35 /// 36 /// For non-parametric storage classes, i.e. singleton classes, nothing else is 37 /// needed. Instances of these classes can be created by calling `get` without 38 /// trailing arguments. 39 /// 40 /// Otherwise, the parametric storage classes may be created with `get`, 41 /// and must respect the following: 42 /// - Define a type alias, KeyTy, to a type that uniquely identifies the 43 /// instance of the storage class. 44 /// * The key type must be constructible from the values passed into the 45 /// getComplex call. 46 /// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the 47 /// storage class must define a hashing method: 48 /// 'static unsigned hashKey(const KeyTy &)' 49 /// 50 /// - Provide a method, 'bool operator==(const KeyTy &) const', to 51 /// compare the storage instance against an instance of the key type. 52 /// 53 /// - Provide a static construction method: 54 /// 'DerivedStorage *construct(StorageAllocator &, const KeyTy &key)' 55 /// that builds a unique instance of the derived storage. The arguments to 56 /// this function are an allocator to store any uniqued data and the key 57 /// type for this storage. 58 /// 59 /// - Provide a cleanup method: 60 /// 'void cleanup()' 61 /// that is called when erasing a storage instance. This should cleanup any 62 /// fields of the storage as necessary and not attempt to free the memory 63 /// of the storage itself. 64 /// 65 /// Storage classes may have an optional mutable component, which must not take 66 /// part in the unique immutable key. In this case, storage classes may be 67 /// mutated with `mutate` and must additionally respect the following: 68 /// - Provide a mutation method: 69 /// 'LogicalResult mutate(StorageAllocator &, <...>)' 70 /// that is called when mutating a storage instance. The first argument is 71 /// an allocator to store any mutable data, and the remaining arguments are 72 /// forwarded from the call site. The storage can be mutated at any time 73 /// after creation. Care must be taken to avoid excessive mutation since 74 /// the allocated storage can keep containing previous states. The return 75 /// value of the function is used to indicate whether the mutation was 76 /// successful, e.g., to limit the number of mutations or enable deferred 77 /// one-time assignment of the mutable component. 78 /// 79 /// All storage classes must be registered with the uniquer via 80 /// `registerParametricStorageType` or `registerSingletonStorageType` 81 /// using an appropriate unique `TypeID` for the storage class. 82 class StorageUniquer { 83 public: 84 /// This class acts as the base storage that all storage classes must derived 85 /// from. 86 class alignas(8) BaseStorage { 87 protected: 88 BaseStorage() = default; 89 }; 90 91 /// This is a utility allocator used to allocate memory for instances of 92 /// derived types. 93 class StorageAllocator { 94 public: 95 /// Copy the specified array of elements into memory managed by our bump 96 /// pointer allocator. This assumes the elements are all PODs. 97 template <typename T> copyInto(ArrayRef<T> elements)98 ArrayRef<T> copyInto(ArrayRef<T> elements) { 99 if (elements.empty()) 100 return std::nullopt; 101 auto result = allocator.Allocate<T>(elements.size()); 102 std::uninitialized_copy(elements.begin(), elements.end(), result); 103 return ArrayRef<T>(result, elements.size()); 104 } 105 106 /// Copy the provided string into memory managed by our bump pointer 107 /// allocator. copyInto(StringRef str)108 StringRef copyInto(StringRef str) { 109 if (str.empty()) 110 return StringRef(); 111 112 char *result = allocator.Allocate<char>(str.size() + 1); 113 std::uninitialized_copy(str.begin(), str.end(), result); 114 result[str.size()] = 0; 115 return StringRef(result, str.size()); 116 } 117 118 /// Allocate an instance of the provided type. 119 template <typename T> allocate()120 T *allocate() { 121 return allocator.Allocate<T>(); 122 } 123 124 /// Allocate 'size' bytes of 'alignment' aligned memory. allocate(size_t size,size_t alignment)125 void *allocate(size_t size, size_t alignment) { 126 return allocator.Allocate(size, alignment); 127 } 128 129 /// Returns true if this allocator allocated the provided object pointer. allocated(const void * ptr)130 bool allocated(const void *ptr) { 131 return allocator.identifyObject(ptr).has_value(); 132 } 133 134 private: 135 /// The raw allocator for type storage objects. 136 llvm::BumpPtrAllocator allocator; 137 }; 138 139 StorageUniquer(); 140 ~StorageUniquer(); 141 142 /// Set the flag specifying if multi-threading is disabled within the uniquer. 143 void disableMultithreading(bool disable = true); 144 145 /// Register a new parametric storage class, this is necessary to create 146 /// instances of this class type. `id` is the type identifier that will be 147 /// used to identify this type when creating instances of it via 'get'. 148 template <typename Storage> registerParametricStorageType(TypeID id)149 void registerParametricStorageType(TypeID id) { 150 // If the storage is trivially destructible, we don't need a destructor 151 // function. 152 if constexpr (std::is_trivially_destructible_v<Storage>) 153 return registerParametricStorageTypeImpl(id, nullptr); 154 registerParametricStorageTypeImpl(id, [](BaseStorage *storage) { 155 static_cast<Storage *>(storage)->~Storage(); 156 }); 157 } 158 /// Utility override when the storage type represents the type id. 159 template <typename Storage> registerParametricStorageType()160 void registerParametricStorageType() { 161 registerParametricStorageType<Storage>(TypeID::get<Storage>()); 162 } 163 /// Register a new singleton storage class, this is necessary to get the 164 /// singletone instance. `id` is the type identifier that will be used to 165 /// access the singleton instance via 'get'. An optional initialization 166 /// function may also be provided to initialize the newly created storage 167 /// instance, and used when the singleton instance is created. 168 template <typename Storage> registerSingletonStorageType(TypeID id,function_ref<void (Storage *)> initFn)169 void registerSingletonStorageType(TypeID id, 170 function_ref<void(Storage *)> initFn) { 171 auto ctorFn = [&](StorageAllocator &allocator) { 172 auto *storage = new (allocator.allocate<Storage>()) Storage(); 173 if (initFn) 174 initFn(storage); 175 return storage; 176 }; 177 registerSingletonImpl(id, ctorFn); 178 } 179 template <typename Storage> registerSingletonStorageType(TypeID id)180 void registerSingletonStorageType(TypeID id) { 181 registerSingletonStorageType<Storage>(id, std::nullopt); 182 } 183 /// Utility override when the storage type represents the type id. 184 template <typename Storage> 185 void registerSingletonStorageType(function_ref<void(Storage *)> initFn = {}) { 186 registerSingletonStorageType<Storage>(TypeID::get<Storage>(), initFn); 187 } 188 189 /// Gets a uniqued instance of 'Storage'. 'id' is the type id used when 190 /// registering the storage instance. 'initFn' is an optional parameter that 191 /// can be used to initialize a newly inserted storage instance. This function 192 /// is used for derived types that have complex storage or uniquing 193 /// constraints. 194 template <typename Storage, typename... Args> get(function_ref<void (Storage *)> initFn,TypeID id,Args &&...args)195 Storage *get(function_ref<void(Storage *)> initFn, TypeID id, 196 Args &&...args) { 197 // Construct a value of the derived key type. 198 auto derivedKey = getKey<Storage>(std::forward<Args>(args)...); 199 200 // Create a hash of the derived key. 201 unsigned hashValue = getHash<Storage>(derivedKey); 202 203 // Generate an equality function for the derived storage. 204 auto isEqual = [&derivedKey](const BaseStorage *existing) { 205 return static_cast<const Storage &>(*existing) == derivedKey; 206 }; 207 208 // Generate a constructor function for the derived storage. 209 auto ctorFn = [&](StorageAllocator &allocator) { 210 auto *storage = Storage::construct(allocator, std::move(derivedKey)); 211 if (initFn) 212 initFn(storage); 213 return storage; 214 }; 215 216 // Get an instance for the derived storage. 217 return static_cast<Storage *>( 218 getParametricStorageTypeImpl(id, hashValue, isEqual, ctorFn)); 219 } 220 /// Utility override when the storage type represents the type id. 221 template <typename Storage, typename... Args> get(function_ref<void (Storage *)> initFn,Args &&...args)222 Storage *get(function_ref<void(Storage *)> initFn, Args &&...args) { 223 return get<Storage>(initFn, TypeID::get<Storage>(), 224 std::forward<Args>(args)...); 225 } 226 227 /// Gets a uniqued instance of 'Storage' which is a singleton storage type. 228 /// 'id' is the type id used when registering the storage instance. 229 template <typename Storage> get(TypeID id)230 Storage *get(TypeID id) { 231 return static_cast<Storage *>(getSingletonImpl(id)); 232 } 233 /// Utility override when the storage type represents the type id. 234 template <typename Storage> get()235 Storage *get() { 236 return get<Storage>(TypeID::get<Storage>()); 237 } 238 239 /// Test if there is a singleton storage uniquer initialized for the provided 240 /// TypeID. This is only useful for debugging/diagnostic purpose: the uniquer 241 /// is initialized when a dialect is loaded. 242 bool isSingletonStorageInitialized(TypeID id); 243 244 /// Test if there is a parametric storage uniquer initialized for the provided 245 /// TypeID. This is only useful for debugging/diagnostic purpose: the uniquer 246 /// is initialized when a dialect is loaded. 247 bool isParametricStorageInitialized(TypeID id); 248 249 /// Changes the mutable component of 'storage' by forwarding the trailing 250 /// arguments to the 'mutate' function of the derived class. 251 template <typename Storage, typename... Args> mutate(TypeID id,Storage * storage,Args &&...args)252 LogicalResult mutate(TypeID id, Storage *storage, Args &&...args) { 253 auto mutationFn = [&](StorageAllocator &allocator) -> LogicalResult { 254 return static_cast<Storage &>(*storage).mutate( 255 allocator, std::forward<Args>(args)...); 256 }; 257 return mutateImpl(id, storage, mutationFn); 258 } 259 260 private: 261 /// Implementation for getting/creating an instance of a derived type with 262 /// parametric storage. 263 BaseStorage *getParametricStorageTypeImpl( 264 TypeID id, unsigned hashValue, 265 function_ref<bool(const BaseStorage *)> isEqual, 266 function_ref<BaseStorage *(StorageAllocator &)> ctorFn); 267 268 /// Implementation for registering an instance of a derived type with 269 /// parametric storage. This method takes an optional destructor function that 270 /// destructs storage instances when necessary. 271 void registerParametricStorageTypeImpl( 272 TypeID id, function_ref<void(BaseStorage *)> destructorFn); 273 274 /// Implementation for getting an instance of a derived type with default 275 /// storage. 276 BaseStorage *getSingletonImpl(TypeID id); 277 278 /// Implementation for registering an instance of a derived type with default 279 /// storage. 280 void 281 registerSingletonImpl(TypeID id, 282 function_ref<BaseStorage *(StorageAllocator &)> ctorFn); 283 284 /// Implementation for mutating an instance of a derived storage. 285 LogicalResult 286 mutateImpl(TypeID id, BaseStorage *storage, 287 function_ref<LogicalResult(StorageAllocator &)> mutationFn); 288 289 /// The internal implementation class. 290 std::unique_ptr<detail::StorageUniquerImpl> impl; 291 292 //===--------------------------------------------------------------------===// 293 // Key Construction 294 //===--------------------------------------------------------------------===// 295 296 /// Used to construct an instance of 'ImplTy::KeyTy' if there is an 297 /// 'ImplTy::getKey' function for the provided arguments. Otherwise, then we 298 /// try to directly construct the 'ImplTy::KeyTy' with the provided arguments. 299 template <typename ImplTy, typename... Args> getKey(Args &&...args)300 static typename ImplTy::KeyTy getKey(Args &&...args) { 301 if constexpr (llvm::is_detected<detail::has_impltype_getkey_t, ImplTy, 302 Args...>::value) 303 return ImplTy::getKey(std::forward<Args>(args)...); 304 else 305 return typename ImplTy::KeyTy(std::forward<Args>(args)...); 306 } 307 308 //===--------------------------------------------------------------------===// 309 // Key Hashing 310 //===--------------------------------------------------------------------===// 311 312 /// Used to generate a hash for the `ImplTy` of a storage instance if 313 /// there is a `ImplTy::hashKey. Otherwise, if there is no `ImplTy::hashKey` 314 /// then default to using the 'llvm::DenseMapInfo' definition for 315 /// 'DerivedKey' for generating a hash. 316 template <typename ImplTy, typename DerivedKey> getHash(const DerivedKey & derivedKey)317 static ::llvm::hash_code getHash(const DerivedKey &derivedKey) { 318 if constexpr (llvm::is_detected<detail::has_impltype_hash_t, ImplTy, 319 DerivedKey>::value) 320 return ImplTy::hashKey(derivedKey); 321 else 322 return DenseMapInfo<DerivedKey>::getHashValue(derivedKey); 323 } 324 }; 325 } // namespace mlir 326 327 #endif 328