xref: /llvm-project/offload/plugins-nextgen/common/include/MemoryManager.h (revision ed512710a5e855a029a05f399335e03db0e704bd)
1 //===----------- MemoryManager.h - Target independent memory manager ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Target independent memory manager.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_MEMORYMANAGER_H
14 #define LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_MEMORYMANAGER_H
15 
16 #include <cassert>
17 #include <functional>
18 #include <list>
19 #include <mutex>
20 #include <set>
21 #include <unordered_map>
22 #include <vector>
23 
24 #include "Shared/Debug.h"
25 #include "Shared/Utils.h"
26 #include "omptarget.h"
27 
28 /// Base class of per-device allocator.
29 class DeviceAllocatorTy {
30 public:
31   virtual ~DeviceAllocatorTy() = default;
32 
33   /// Allocate a memory of size \p Size . \p HstPtr is used to assist the
34   /// allocation.
35   virtual void *allocate(size_t Size, void *HstPtr,
36                          TargetAllocTy Kind = TARGET_ALLOC_DEFAULT) = 0;
37 
38   /// Delete the pointer \p TgtPtr on the device
39   virtual int free(void *TgtPtr, TargetAllocTy Kind = TARGET_ALLOC_DEFAULT) = 0;
40 };
41 
42 /// Class of memory manager. The memory manager is per-device by using
43 /// per-device allocator. Therefore, each plugin using memory manager should
44 /// have an allocator for each device.
45 class MemoryManagerTy {
46   static constexpr const size_t BucketSize[] = {
47       0,       1U << 2, 1U << 3,  1U << 4,  1U << 5,  1U << 6, 1U << 7,
48       1U << 8, 1U << 9, 1U << 10, 1U << 11, 1U << 12, 1U << 13};
49 
50   static constexpr const int NumBuckets =
51       sizeof(BucketSize) / sizeof(BucketSize[0]);
52 
53   /// Find the previous number that is power of 2 given a number that is not
54   /// power of 2.
55   static size_t floorToPowerOfTwo(size_t Num) {
56     Num |= Num >> 1;
57     Num |= Num >> 2;
58     Num |= Num >> 4;
59     Num |= Num >> 8;
60     Num |= Num >> 16;
61 #if INTPTR_MAX == INT64_MAX
62     Num |= Num >> 32;
63 #elif INTPTR_MAX == INT32_MAX
64     // Do nothing with 32-bit
65 #else
66 #error Unsupported architecture
67 #endif
68     Num += 1;
69     return Num >> 1;
70   }
71 
72   /// Find a suitable bucket
73   static int findBucket(size_t Size) {
74     const size_t F = floorToPowerOfTwo(Size);
75 
76     DP("findBucket: Size %zu is floored to %zu.\n", Size, F);
77 
78     int L = 0, H = NumBuckets - 1;
79     while (H - L > 1) {
80       int M = (L + H) >> 1;
81       if (BucketSize[M] == F)
82         return M;
83       if (BucketSize[M] > F)
84         H = M - 1;
85       else
86         L = M;
87     }
88 
89     assert(L >= 0 && L < NumBuckets && "L is out of range");
90 
91     DP("findBucket: Size %zu goes to bucket %d\n", Size, L);
92 
93     return L;
94   }
95 
96   /// A structure stores the meta data of a target pointer
97   struct NodeTy {
98     /// Memory size
99     const size_t Size;
100     /// Target pointer
101     void *Ptr;
102 
103     /// Constructor
104     NodeTy(size_t Size, void *Ptr) : Size(Size), Ptr(Ptr) {}
105   };
106 
107   /// To make \p NodePtrTy ordered when they're put into \p std::multiset.
108   struct NodeCmpTy {
109     bool operator()(const NodeTy &LHS, const NodeTy &RHS) const {
110       return LHS.Size < RHS.Size;
111     }
112   };
113 
114   /// A \p FreeList is a set of Nodes. We're using \p std::multiset here to make
115   /// the look up procedure more efficient.
116   using FreeListTy = std::multiset<std::reference_wrapper<NodeTy>, NodeCmpTy>;
117 
118   /// A list of \p FreeListTy entries, each of which is a \p std::multiset of
119   /// Nodes whose size is less or equal to a specific bucket size.
120   std::vector<FreeListTy> FreeLists;
121   /// A list of mutex for each \p FreeListTy entry
122   std::vector<std::mutex> FreeListLocks;
123   /// A table to map from a target pointer to its node
124   std::unordered_map<void *, NodeTy> PtrToNodeTable;
125   /// The mutex for the table \p PtrToNodeTable
126   std::mutex MapTableLock;
127 
128   /// The reference to a device allocator
129   DeviceAllocatorTy &DeviceAllocator;
130 
131   /// The threshold to manage memory using memory manager. If the request size
132   /// is larger than \p SizeThreshold, the allocation will not be managed by the
133   /// memory manager.
134   size_t SizeThreshold = 1U << 13;
135 
136   /// Request memory from target device
137   void *allocateOnDevice(size_t Size, void *HstPtr) const {
138     return DeviceAllocator.allocate(Size, HstPtr, TARGET_ALLOC_DEVICE);
139   }
140 
141   /// Deallocate data on device
142   int deleteOnDevice(void *Ptr) const { return DeviceAllocator.free(Ptr); }
143 
144   /// This function is called when it tries to allocate memory on device but the
145   /// device returns out of memory. It will first free all memory in the
146   /// FreeList and try to allocate again.
147   void *freeAndAllocate(size_t Size, void *HstPtr) {
148     std::vector<void *> RemoveList;
149 
150     // Deallocate all memory in FreeList
151     for (int I = 0; I < NumBuckets; ++I) {
152       FreeListTy &List = FreeLists[I];
153       std::lock_guard<std::mutex> Lock(FreeListLocks[I]);
154       if (List.empty())
155         continue;
156       for (const NodeTy &N : List) {
157         deleteOnDevice(N.Ptr);
158         RemoveList.push_back(N.Ptr);
159       }
160       FreeLists[I].clear();
161     }
162 
163     // Remove all nodes in the map table which have been released
164     if (!RemoveList.empty()) {
165       std::lock_guard<std::mutex> LG(MapTableLock);
166       for (void *P : RemoveList)
167         PtrToNodeTable.erase(P);
168     }
169 
170     // Try allocate memory again
171     return allocateOnDevice(Size, HstPtr);
172   }
173 
174   /// The goal is to allocate memory on the device. It first tries to
175   /// allocate directly on the device. If a \p nullptr is returned, it might
176   /// be because the device is OOM. In that case, it will free all unused
177   /// memory and then try again.
178   void *allocateOrFreeAndAllocateOnDevice(size_t Size, void *HstPtr) {
179     void *TgtPtr = allocateOnDevice(Size, HstPtr);
180     // We cannot get memory from the device. It might be due to OOM. Let's
181     // free all memory in FreeLists and try again.
182     if (TgtPtr == nullptr) {
183       DP("Failed to get memory on device. Free all memory in FreeLists and "
184          "try again.\n");
185       TgtPtr = freeAndAllocate(Size, HstPtr);
186     }
187 
188     if (TgtPtr == nullptr)
189       DP("Still cannot get memory on device probably because the device is "
190          "OOM.\n");
191 
192     return TgtPtr;
193   }
194 
195 public:
196   /// Constructor. If \p Threshold is non-zero, then the default threshold will
197   /// be overwritten by \p Threshold.
198   MemoryManagerTy(DeviceAllocatorTy &DeviceAllocator, size_t Threshold = 0)
199       : FreeLists(NumBuckets), FreeListLocks(NumBuckets),
200         DeviceAllocator(DeviceAllocator) {
201     if (Threshold)
202       SizeThreshold = Threshold;
203   }
204 
205   /// Destructor
206   ~MemoryManagerTy() {
207     for (auto Itr = PtrToNodeTable.begin(); Itr != PtrToNodeTable.end();
208          ++Itr) {
209       assert(Itr->second.Ptr && "nullptr in map table");
210       deleteOnDevice(Itr->second.Ptr);
211     }
212   }
213 
214   /// Allocate memory of size \p Size from target device. \p HstPtr is used to
215   /// assist the allocation.
216   void *allocate(size_t Size, void *HstPtr) {
217     // If the size is zero, we will not bother the target device. Just return
218     // nullptr directly.
219     if (Size == 0)
220       return nullptr;
221 
222     DP("MemoryManagerTy::allocate: size %zu with host pointer " DPxMOD ".\n",
223        Size, DPxPTR(HstPtr));
224 
225     // If the size is greater than the threshold, allocate it directly from
226     // device.
227     if (Size > SizeThreshold) {
228       DP("%zu is greater than the threshold %zu. Allocate it directly from "
229          "device\n",
230          Size, SizeThreshold);
231       void *TgtPtr = allocateOrFreeAndAllocateOnDevice(Size, HstPtr);
232 
233       DP("Got target pointer " DPxMOD ". Return directly.\n", DPxPTR(TgtPtr));
234 
235       return TgtPtr;
236     }
237 
238     NodeTy *NodePtr = nullptr;
239 
240     // Try to get a node from FreeList
241     {
242       const int B = findBucket(Size);
243       FreeListTy &List = FreeLists[B];
244 
245       NodeTy TempNode(Size, nullptr);
246       std::lock_guard<std::mutex> LG(FreeListLocks[B]);
247       const auto Itr = List.find(TempNode);
248 
249       if (Itr != List.end()) {
250         NodePtr = &Itr->get();
251         List.erase(Itr);
252       }
253     }
254 
255     if (NodePtr != nullptr)
256       DP("Find one node " DPxMOD " in the bucket.\n", DPxPTR(NodePtr));
257 
258     // We cannot find a valid node in FreeLists. Let's allocate on device and
259     // create a node for it.
260     if (NodePtr == nullptr) {
261       DP("Cannot find a node in the FreeLists. Allocate on device.\n");
262       // Allocate one on device
263       void *TgtPtr = allocateOrFreeAndAllocateOnDevice(Size, HstPtr);
264 
265       if (TgtPtr == nullptr)
266         return nullptr;
267 
268       // Create a new node and add it into the map table
269       {
270         std::lock_guard<std::mutex> Guard(MapTableLock);
271         auto Itr = PtrToNodeTable.emplace(TgtPtr, NodeTy(Size, TgtPtr));
272         NodePtr = &Itr.first->second;
273       }
274 
275       DP("Node address " DPxMOD ", target pointer " DPxMOD ", size %zu\n",
276          DPxPTR(NodePtr), DPxPTR(TgtPtr), Size);
277     }
278 
279     assert(NodePtr && "NodePtr should not be nullptr at this point");
280 
281     return NodePtr->Ptr;
282   }
283 
284   /// Deallocate memory pointed by \p TgtPtr
285   int free(void *TgtPtr) {
286     DP("MemoryManagerTy::free: target memory " DPxMOD ".\n", DPxPTR(TgtPtr));
287 
288     NodeTy *P = nullptr;
289 
290     // Look it up into the table
291     {
292       std::lock_guard<std::mutex> G(MapTableLock);
293       auto Itr = PtrToNodeTable.find(TgtPtr);
294 
295       // We don't remove the node from the map table because the map does not
296       // change.
297       if (Itr != PtrToNodeTable.end())
298         P = &Itr->second;
299     }
300 
301     // The memory is not managed by the manager
302     if (P == nullptr) {
303       DP("Cannot find its node. Delete it on device directly.\n");
304       return deleteOnDevice(TgtPtr);
305     }
306 
307     // Insert the node to the free list
308     const int B = findBucket(P->Size);
309 
310     DP("Found its node " DPxMOD ". Insert it to bucket %d.\n", DPxPTR(P), B);
311 
312     {
313       std::lock_guard<std::mutex> G(FreeListLocks[B]);
314       FreeLists[B].insert(*P);
315     }
316 
317     return OFFLOAD_SUCCESS;
318   }
319 
320   /// Get the size threshold from the environment variable
321   /// \p LIBOMPTARGET_MEMORY_MANAGER_THRESHOLD . Returns a <tt>
322   /// std::pair<size_t, bool> </tt> where the first element represents the
323   /// threshold and the second element represents whether user disables memory
324   /// manager explicitly by setting the var to 0. If user doesn't specify
325   /// anything, returns <0, true>.
326   static std::pair<size_t, bool> getSizeThresholdFromEnv() {
327     static UInt64Envar MemoryManagerThreshold(
328         "LIBOMPTARGET_MEMORY_MANAGER_THRESHOLD", 0);
329 
330     size_t Threshold = MemoryManagerThreshold.get();
331 
332     if (MemoryManagerThreshold.isPresent() && Threshold == 0) {
333       DP("Disabled memory manager as user set "
334          "LIBOMPTARGET_MEMORY_MANAGER_THRESHOLD=0.\n");
335       return std::make_pair(0, false);
336     }
337 
338     return std::make_pair(Threshold, true);
339   }
340 };
341 
342 // GCC still cannot handle the static data member like Clang so we still need
343 // this part.
344 constexpr const size_t MemoryManagerTy::BucketSize[];
345 constexpr const int MemoryManagerTy::NumBuckets;
346 
347 #endif // LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_MEMORYMANAGER_H
348