1 //===- ThreadLocalCache.h - ThreadLocalCache class --------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains a definition of the ThreadLocalCache class. This class 10 // provides support for defining thread local objects with non-static duration. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_SUPPORT_THREADLOCALCACHE_H 15 #define MLIR_SUPPORT_THREADLOCALCACHE_H 16 17 #include "mlir/Support/LLVM.h" 18 #include "llvm/ADT/DenseMap.h" 19 #include "llvm/Support/ManagedStatic.h" 20 #include "llvm/Support/Mutex.h" 21 22 namespace mlir { 23 /// This class provides support for defining a thread local object with non 24 /// static storage duration. This is very useful for situations in which a data 25 /// cache has very large lock contention. 26 template <typename ValueT> 27 class ThreadLocalCache { 28 struct PerInstanceState; 29 30 using PointerAndFlag = std::pair<ValueT *, std::atomic<bool>>; 31 32 /// The "observer" is owned by a thread-local cache instance. It is 33 /// constructed the first time a `ThreadLocalCache` instance is accessed by a 34 /// thread, unless `perInstanceState` happens to get re-allocated to the same 35 /// address as a previous one. A `thread_local` instance of this class is 36 /// destructed when the thread in which it lives is destroyed. 37 /// 38 /// This class is called the "observer" because while values cached in 39 /// thread-local caches are owned by `PerInstanceState`, a reference is stored 40 /// via this class in the TLC. With a double pointer, it knows when the 41 /// referenced value has been destroyed. 42 struct Observer { 43 /// This is the double pointer, explicitly allocated because we need to keep 44 /// the address stable if the TLC map re-allocates. It is owned by the 45 /// observer and shared with the value owner. 46 std::shared_ptr<PointerAndFlag> ptr = 47 std::make_shared<PointerAndFlag>(std::make_pair(nullptr, false)); 48 /// Because the `Owner` instance that lives inside `PerInstanceState` 49 /// contains a reference to the double pointer, and likewise this class 50 /// contains a reference to the value, we need to synchronize destruction of 51 /// the TLC and the `PerInstanceState` to avoid racing. This weak pointer is 52 /// acquired during TLC destruction if the `PerInstanceState` hasn't entered 53 /// its destructor yet, and prevents it from happening. 54 std::weak_ptr<PerInstanceState> keepalive; 55 }; 56 57 /// This struct owns the cache entries. It contains a reference back to the 58 /// reference inside the cache so that it can be written to null to indicate 59 /// that the cache entry is invalidated. It needs to do this because 60 /// `perInstanceState` could get re-allocated to the same pointer and we don't 61 /// remove entries from the TLC when it is deallocated. Thus, we have to reset 62 /// the TLC entries to a starting state in case the `ThreadLocalCache` lives 63 /// shorter than the threads. 64 struct Owner { 65 /// Save a pointer to the reference and write it to the newly created entry. 66 Owner(Observer &observer) 67 : value(std::make_unique<ValueT>()), ptrRef(observer.ptr) { 68 observer.ptr->second = true; 69 observer.ptr->first = value.get(); 70 } 71 ~Owner() { 72 if (std::shared_ptr<PointerAndFlag> ptr = ptrRef.lock()) { 73 ptr->first = nullptr; 74 ptr->second = false; 75 } 76 } 77 78 Owner(Owner &&) = default; 79 Owner &operator=(Owner &&) = default; 80 81 std::unique_ptr<ValueT> value; 82 std::weak_ptr<PointerAndFlag> ptrRef; 83 }; 84 85 // Keep a separate shared_ptr protected state that can be acquired atomically 86 // instead of using shared_ptr's for each value. This avoids a problem 87 // where the instance shared_ptr is locked() successfully, and then the 88 // ThreadLocalCache gets destroyed before remove() can be called successfully. 89 struct PerInstanceState { 90 /// Remove the given value entry. This is called when a thread local cache 91 /// is destructing but still contains references to values owned by the 92 /// `PerInstanceState`. Removal is required because it prevents writeback to 93 /// a pointer that was deallocated. 94 void remove(ValueT *value) { 95 // Erase the found value directly, because it is guaranteed to be in the 96 // list. 97 llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex); 98 auto it = llvm::find_if(instances, [&](Owner &instance) { 99 return instance.value.get() == value; 100 }); 101 assert(it != instances.end() && "expected value to exist in cache"); 102 instances.erase(it); 103 } 104 105 /// Owning pointers to all of the values that have been constructed for this 106 /// object in the static cache. 107 SmallVector<Owner, 1> instances; 108 109 /// A mutex used when a new thread instance has been added to the cache for 110 /// this object. 111 llvm::sys::SmartMutex<true> instanceMutex; 112 }; 113 114 /// The type used for the static thread_local cache. This is a map between an 115 /// instance of the non-static cache and a weak reference to an instance of 116 /// ValueT. We use a weak reference here so that the object can be destroyed 117 /// without needing to lock access to the cache itself. 118 struct CacheType : public llvm::SmallDenseMap<PerInstanceState *, Observer> { 119 ~CacheType() { 120 // Remove the values of this cache that haven't already expired. This is 121 // required because if we don't remove them, they will contain a reference 122 // back to the data here that is being destroyed. 123 for (auto &[instance, observer] : *this) 124 if (std::shared_ptr<PerInstanceState> state = observer.keepalive.lock()) 125 state->remove(observer.ptr->first); 126 } 127 128 /// Clear out any unused entries within the map. This method is not 129 /// thread-safe, and should only be called by the same thread as the cache. 130 void clearExpiredEntries() { 131 for (auto it = this->begin(), e = this->end(); it != e;) { 132 auto curIt = it++; 133 if (!curIt->second.ptr->second) 134 this->erase(curIt); 135 } 136 } 137 }; 138 139 public: 140 ThreadLocalCache() = default; 141 ~ThreadLocalCache() { 142 // No cleanup is necessary here as the shared_pointer memory will go out of 143 // scope and invalidate the weak pointers held by the thread_local caches. 144 } 145 146 /// Return an instance of the value type for the current thread. 147 ValueT &get() { 148 // Check for an already existing instance for this thread. 149 CacheType &staticCache = getStaticCache(); 150 Observer &threadInstance = staticCache[perInstanceState.get()]; 151 if (ValueT *value = threadInstance.ptr->first) 152 return *value; 153 154 // Otherwise, create a new instance for this thread. 155 { 156 llvm::sys::SmartScopedLock<true> threadInstanceLock( 157 perInstanceState->instanceMutex); 158 perInstanceState->instances.emplace_back(threadInstance); 159 } 160 threadInstance.keepalive = perInstanceState; 161 162 // Before returning the new instance, take the chance to clear out any used 163 // entries in the static map. The cache is only cleared within the same 164 // thread to remove the need to lock the cache itself. 165 staticCache.clearExpiredEntries(); 166 return *threadInstance.ptr->first; 167 } 168 ValueT &operator*() { return get(); } 169 ValueT *operator->() { return &get(); } 170 171 private: 172 ThreadLocalCache(ThreadLocalCache &&) = delete; 173 ThreadLocalCache(const ThreadLocalCache &) = delete; 174 ThreadLocalCache &operator=(const ThreadLocalCache &) = delete; 175 176 /// Return the static thread local instance of the cache type. 177 static CacheType &getStaticCache() { 178 static LLVM_THREAD_LOCAL CacheType cache; 179 return cache; 180 } 181 182 std::shared_ptr<PerInstanceState> perInstanceState = 183 std::make_shared<PerInstanceState>(); 184 }; 185 } // namespace mlir 186 187 #endif // MLIR_SUPPORT_THREADLOCALCACHE_H 188