xref: /llvm-project/mlir/include/mlir/Support/ThreadLocalCache.h (revision ce2b488e90d6d5a5c8fe495ede8238938827da39)
1 //===- ThreadLocalCache.h - ThreadLocalCache class --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a definition of the ThreadLocalCache class. This class
10 // provides support for defining thread local objects with non-static duration.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_SUPPORT_THREADLOCALCACHE_H
15 #define MLIR_SUPPORT_THREADLOCALCACHE_H
16 
17 #include "mlir/Support/LLVM.h"
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/Support/ManagedStatic.h"
20 #include "llvm/Support/Mutex.h"
21 
22 namespace mlir {
23 /// This class provides support for defining a thread local object with non
24 /// static storage duration. This is very useful for situations in which a data
25 /// cache has very large lock contention.
26 template <typename ValueT>
27 class ThreadLocalCache {
28   struct PerInstanceState;
29 
30   using PointerAndFlag = std::pair<ValueT *, std::atomic<bool>>;
31 
32   /// The "observer" is owned by a thread-local cache instance. It is
33   /// constructed the first time a `ThreadLocalCache` instance is accessed by a
34   /// thread, unless `perInstanceState` happens to get re-allocated to the same
35   /// address as a previous one. A `thread_local` instance of this class is
36   /// destructed when the thread in which it lives is destroyed.
37   ///
38   /// This class is called the "observer" because while values cached in
39   /// thread-local caches are owned by `PerInstanceState`, a reference is stored
40   /// via this class in the TLC. With a double pointer, it knows when the
41   /// referenced value has been destroyed.
42   struct Observer {
43     /// This is the double pointer, explicitly allocated because we need to keep
44     /// the address stable if the TLC map re-allocates. It is owned by the
45     /// observer and shared with the value owner.
46     std::shared_ptr<PointerAndFlag> ptr =
47         std::make_shared<PointerAndFlag>(std::make_pair(nullptr, false));
48     /// Because the `Owner` instance that lives inside `PerInstanceState`
49     /// contains a reference to the double pointer, and likewise this class
50     /// contains a reference to the value, we need to synchronize destruction of
51     /// the TLC and the `PerInstanceState` to avoid racing. This weak pointer is
52     /// acquired during TLC destruction if the `PerInstanceState` hasn't entered
53     /// its destructor yet, and prevents it from happening.
54     std::weak_ptr<PerInstanceState> keepalive;
55   };
56 
57   /// This struct owns the cache entries. It contains a reference back to the
58   /// reference inside the cache so that it can be written to null to indicate
59   /// that the cache entry is invalidated. It needs to do this because
60   /// `perInstanceState` could get re-allocated to the same pointer and we don't
61   /// remove entries from the TLC when it is deallocated. Thus, we have to reset
62   /// the TLC entries to a starting state in case the `ThreadLocalCache` lives
63   /// shorter than the threads.
64   struct Owner {
65     /// Save a pointer to the reference and write it to the newly created entry.
66     Owner(Observer &observer)
67         : value(std::make_unique<ValueT>()), ptrRef(observer.ptr) {
68       observer.ptr->second = true;
69       observer.ptr->first = value.get();
70     }
71     ~Owner() {
72       if (std::shared_ptr<PointerAndFlag> ptr = ptrRef.lock()) {
73         ptr->first = nullptr;
74         ptr->second = false;
75       }
76     }
77 
78     Owner(Owner &&) = default;
79     Owner &operator=(Owner &&) = default;
80 
81     std::unique_ptr<ValueT> value;
82     std::weak_ptr<PointerAndFlag> ptrRef;
83   };
84 
85   // Keep a separate shared_ptr protected state that can be acquired atomically
86   // instead of using shared_ptr's for each value. This avoids a problem
87   // where the instance shared_ptr is locked() successfully, and then the
88   // ThreadLocalCache gets destroyed before remove() can be called successfully.
89   struct PerInstanceState {
90     /// Remove the given value entry. This is called when a thread local cache
91     /// is destructing but still contains references to values owned by the
92     /// `PerInstanceState`. Removal is required because it prevents writeback to
93     /// a pointer that was deallocated.
94     void remove(ValueT *value) {
95       // Erase the found value directly, because it is guaranteed to be in the
96       // list.
97       llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
98       auto it = llvm::find_if(instances, [&](Owner &instance) {
99         return instance.value.get() == value;
100       });
101       assert(it != instances.end() && "expected value to exist in cache");
102       instances.erase(it);
103     }
104 
105     /// Owning pointers to all of the values that have been constructed for this
106     /// object in the static cache.
107     SmallVector<Owner, 1> instances;
108 
109     /// A mutex used when a new thread instance has been added to the cache for
110     /// this object.
111     llvm::sys::SmartMutex<true> instanceMutex;
112   };
113 
114   /// The type used for the static thread_local cache. This is a map between an
115   /// instance of the non-static cache and a weak reference to an instance of
116   /// ValueT. We use a weak reference here so that the object can be destroyed
117   /// without needing to lock access to the cache itself.
118   struct CacheType : public llvm::SmallDenseMap<PerInstanceState *, Observer> {
119     ~CacheType() {
120       // Remove the values of this cache that haven't already expired. This is
121       // required because if we don't remove them, they will contain a reference
122       // back to the data here that is being destroyed.
123       for (auto &[instance, observer] : *this)
124         if (std::shared_ptr<PerInstanceState> state = observer.keepalive.lock())
125           state->remove(observer.ptr->first);
126     }
127 
128     /// Clear out any unused entries within the map. This method is not
129     /// thread-safe, and should only be called by the same thread as the cache.
130     void clearExpiredEntries() {
131       for (auto it = this->begin(), e = this->end(); it != e;) {
132         auto curIt = it++;
133         if (!curIt->second.ptr->second)
134           this->erase(curIt);
135       }
136     }
137   };
138 
139 public:
140   ThreadLocalCache() = default;
141   ~ThreadLocalCache() {
142     // No cleanup is necessary here as the shared_pointer memory will go out of
143     // scope and invalidate the weak pointers held by the thread_local caches.
144   }
145 
146   /// Return an instance of the value type for the current thread.
147   ValueT &get() {
148     // Check for an already existing instance for this thread.
149     CacheType &staticCache = getStaticCache();
150     Observer &threadInstance = staticCache[perInstanceState.get()];
151     if (ValueT *value = threadInstance.ptr->first)
152       return *value;
153 
154     // Otherwise, create a new instance for this thread.
155     {
156       llvm::sys::SmartScopedLock<true> threadInstanceLock(
157           perInstanceState->instanceMutex);
158       perInstanceState->instances.emplace_back(threadInstance);
159     }
160     threadInstance.keepalive = perInstanceState;
161 
162     // Before returning the new instance, take the chance to clear out any used
163     // entries in the static map. The cache is only cleared within the same
164     // thread to remove the need to lock the cache itself.
165     staticCache.clearExpiredEntries();
166     return *threadInstance.ptr->first;
167   }
168   ValueT &operator*() { return get(); }
169   ValueT *operator->() { return &get(); }
170 
171 private:
172   ThreadLocalCache(ThreadLocalCache &&) = delete;
173   ThreadLocalCache(const ThreadLocalCache &) = delete;
174   ThreadLocalCache &operator=(const ThreadLocalCache &) = delete;
175 
176   /// Return the static thread local instance of the cache type.
177   static CacheType &getStaticCache() {
178     static LLVM_THREAD_LOCAL CacheType cache;
179     return cache;
180   }
181 
182   std::shared_ptr<PerInstanceState> perInstanceState =
183       std::make_shared<PerInstanceState>();
184 };
185 } // namespace mlir
186 
187 #endif // MLIR_SUPPORT_THREADLOCALCACHE_H
188