xref: /llvm-project/llvm/lib/Support/TrieRawHashMap.cpp (revision 03948882d3bac33cf71a47df1c7ee0f87aad9fc2)
1 //===- TrieRawHashMap.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/ADT/TrieRawHashMap.h"
10 #include "llvm/ADT/LazyAtomicPointer.h"
11 #include "llvm/ADT/StringExtras.h"
12 #include "llvm/ADT/TrieHashIndexGenerator.h"
13 #include "llvm/Support/Allocator.h"
14 #include "llvm/Support/Casting.h"
15 #include "llvm/Support/Debug.h"
16 #include "llvm/Support/ThreadSafeAllocator.h"
17 #include "llvm/Support/TrailingObjects.h"
18 #include "llvm/Support/raw_ostream.h"
19 #include <memory>
20 
21 using namespace llvm;
22 
23 namespace {
/// Common base of every node that can be stored in a trie slot.
///
/// The only state is a flag that distinguishes the two node kinds; the
/// classof() functions of the derived types read it so that isa<>, cast<> and
/// dyn_cast<> work without virtual functions.
struct TrieNode {
  /// True when the node is a subtrie, false when it is a content node.
  const bool IsSubtrie = false;

  /// Explicit so that a bare bool can never silently convert into a node.
  explicit TrieNode(bool IsSubtrie) : IsSubtrie(IsSubtrie) {}

  // Class-level allocation functions that simply forward to the global ones.
  static void *operator new(size_t Size) { return ::operator new(Size); }
  void operator delete(void *Ptr) { ::operator delete(Ptr); }
};
32 
33 struct TrieContent final : public TrieNode {
34   const uint8_t ContentOffset;
35   const uint8_t HashSize;
36   const uint8_t HashOffset;
37 
38   void *getValuePointer() const {
39     auto *Content = reinterpret_cast<const uint8_t *>(this) + ContentOffset;
40     return const_cast<uint8_t *>(Content);
41   }
42 
43   ArrayRef<uint8_t> getHash() const {
44     auto *Begin = reinterpret_cast<const uint8_t *>(this) + HashOffset;
45     return ArrayRef(Begin, Begin + HashSize);
46   }
47 
48   TrieContent(size_t ContentOffset, size_t HashSize, size_t HashOffset)
49       : TrieNode(/*IsSubtrie=*/false), ContentOffset(ContentOffset),
50         HashSize(HashSize), HashOffset(HashOffset) {}
51 
52   static bool classof(const TrieNode *TN) { return !TN->IsSubtrie; }
53 };
54 
55 static_assert(sizeof(TrieContent) ==
56                   ThreadSafeTrieRawHashMapBase::TrieContentBaseSize,
57               "Check header assumption!");
58 
59 class TrieSubtrie final
60     : public TrieNode,
61       private TrailingObjects<TrieSubtrie, LazyAtomicPointer<TrieNode>> {
62 public:
63   using Slot = LazyAtomicPointer<TrieNode>;
64 
65   Slot &get(size_t I) { return getTrailingObjects<Slot>()[I]; }
66   TrieNode *load(size_t I) { return get(I).load(); }
67 
68   unsigned size() const { return Size; }
69 
70   TrieSubtrie *
71   sink(size_t I, TrieContent &Content, size_t NumSubtrieBits, size_t NewI,
72        function_ref<TrieSubtrie *(std::unique_ptr<TrieSubtrie>)> Saver);
73 
74   static std::unique_ptr<TrieSubtrie> create(size_t StartBit, size_t NumBits);
75 
76   explicit TrieSubtrie(size_t StartBit, size_t NumBits);
77 
78   static bool classof(const TrieNode *TN) { return TN->IsSubtrie; }
79 
80   static constexpr size_t sizeToAlloc(unsigned NumBits) {
81     assert(NumBits < 20 && "Tries should have fewer than ~1M slots");
82     unsigned Count = 1u << NumBits;
83     return totalSizeToAlloc<LazyAtomicPointer<TrieNode>>(Count);
84   }
85 
86 private:
87   // FIXME: Use a bitset to speed up access:
88   //
89   //     std::array<std::atomic<uint64_t>, NumSlots/64> IsSet;
90   //
91   // This will avoid needing to visit sparsely filled slots in
92   // \a ThreadSafeTrieRawHashMapBase::destroyImpl() when there's a non-trivial
93   // destructor.
94   //
95   // It would also greatly speed up iteration, if we add that some day, and
96   // allow get() to return one level sooner.
97   //
98   // This would be the algorithm for updating IsSet (after updating Slots):
99   //
100   //     std::atomic<uint64_t> &Bits = IsSet[I.High];
101   //     const uint64_t NewBit = 1ULL << I.Low;
102   //     uint64_t Old = 0;
103   //     while (!Bits.compare_exchange_weak(Old, Old | NewBit))
104   //       ;
105 
106   // For debugging.
107   unsigned StartBit = 0;
108   unsigned NumBits = 0;
109   unsigned Size = 0;
110   friend class llvm::ThreadSafeTrieRawHashMapBase;
111   friend class TrailingObjects;
112 
113 public:
114   /// Linked list for ownership of tries. The pointer is owned by TrieSubtrie.
115   std::atomic<TrieSubtrie *> Next;
116 };
117 } // end namespace
118 
119 std::unique_ptr<TrieSubtrie> TrieSubtrie::create(size_t StartBit,
120                                                  size_t NumBits) {
121   void *Memory = ::operator new(sizeToAlloc(NumBits));
122   TrieSubtrie *S = ::new (Memory) TrieSubtrie(StartBit, NumBits);
123   return std::unique_ptr<TrieSubtrie>(S);
124 }
125 
126 TrieSubtrie::TrieSubtrie(size_t StartBit, size_t NumBits)
127     : TrieNode(true), StartBit(StartBit), NumBits(NumBits), Size(1u << NumBits),
128       Next(nullptr) {
129   for (unsigned I = 0; I < Size; ++I)
130     new (&get(I)) Slot(nullptr);
131 
132   static_assert(
133       std::is_trivially_destructible<LazyAtomicPointer<TrieNode>>::value,
134       "Expected no work in destructor for TrieNode");
135 }
136 
137 // Sink the nodes down sub-trie when the object being inserted collides with
138 // the index of existing object in the trie. In this case, a new sub-trie needs
139 // to be allocated to hold existing object.
140 TrieSubtrie *TrieSubtrie::sink(
141     size_t I, TrieContent &Content, size_t NumSubtrieBits, size_t NewI,
142     function_ref<TrieSubtrie *(std::unique_ptr<TrieSubtrie>)> Saver) {
143   // Create a new sub-trie that points to the existing object with the new
144   // index for the next level.
145   assert(NumSubtrieBits > 0);
146   std::unique_ptr<TrieSubtrie> S = create(StartBit + NumBits, NumSubtrieBits);
147 
148   assert(NewI < Size);
149   S->get(NewI).store(&Content);
150 
151   // Using compare_exchange to atomically add back the new sub-trie to the trie
152   // in the place of the exsiting object.
153   TrieNode *ExistingNode = &Content;
154   assert(I < Size);
155   if (get(I).compare_exchange_strong(ExistingNode, S.get()))
156     return Saver(std::move(S));
157 
158   // Another thread created a subtrie already. Return it and let "S" be
159   // destructed.
160   return cast<TrieSubtrie>(ExistingNode);
161 }
162 
163 class ThreadSafeTrieRawHashMapBase::ImplType final
164     : private TrailingObjects<ThreadSafeTrieRawHashMapBase::ImplType,
165                               TrieSubtrie> {
166 public:
167   static std::unique_ptr<ImplType> create(size_t StartBit, size_t NumBits) {
168     size_t Size = sizeof(ImplType) + TrieSubtrie::sizeToAlloc(NumBits);
169     void *Memory = ::operator new(Size);
170     ImplType *Impl = ::new (Memory) ImplType(StartBit, NumBits);
171     return std::unique_ptr<ImplType>(Impl);
172   }
173 
174   // Save the Subtrie into the ownship list of the trie structure in a
175   // thread-safe way. The ownership transfer is done by compare_exchange the
176   // pointer value inside the unique_ptr.
177   TrieSubtrie *save(std::unique_ptr<TrieSubtrie> S) {
178     assert(!S->Next && "Expected S to a freshly-constructed leaf");
179 
180     TrieSubtrie *CurrentHead = nullptr;
181     // Add ownership of "S" to front of the list, so that Root -> S ->
182     // Root.Next. This works by repeatedly setting S->Next to a candidate value
183     // of Root.Next (initially nullptr), then setting Root.Next to S once the
184     // candidate matches reality.
185     while (!getRoot()->Next.compare_exchange_weak(CurrentHead, S.get()))
186       S->Next.exchange(CurrentHead);
187 
188     // Ownership transferred to subtrie successfully. Release the unique_ptr.
189     return S.release();
190   }
191 
192   // Get the root which is the trailing object.
193   TrieSubtrie *getRoot() { return getTrailingObjects<TrieSubtrie>(); }
194 
195   static void *operator new(size_t Size) { return ::operator new(Size); }
196   void operator delete(void *Ptr) { ::operator delete(Ptr); }
197 
198   /// FIXME: This should take a function that allocates and constructs the
199   /// content lazily (taking the hash as a separate parameter), in case of
200   /// collision.
201   ThreadSafeAllocator<BumpPtrAllocator> ContentAlloc;
202 
203 private:
204   friend class TrailingObjects;
205 
206   ImplType(size_t StartBit, size_t NumBits) {
207     ::new (getRoot()) TrieSubtrie(StartBit, NumBits);
208   }
209 };
210 
211 ThreadSafeTrieRawHashMapBase::ImplType &
212 ThreadSafeTrieRawHashMapBase::getOrCreateImpl() {
213   if (ImplType *Impl = ImplPtr.load())
214     return *Impl;
215 
216   // Create a new ImplType and store it if another thread doesn't do so first.
217   // If another thread wins this one is destroyed locally.
218   std::unique_ptr<ImplType> Impl = ImplType::create(0, NumRootBits);
219   ImplType *ExistingImpl = nullptr;
220 
221   // If the ownership transferred succesfully, release unique_ptr and return
222   // the pointer to the new ImplType.
223   if (ImplPtr.compare_exchange_strong(ExistingImpl, Impl.get()))
224     return *Impl.release();
225 
226   // Already created, return the existing ImplType.
227   return *ExistingImpl;
228 }
229 
230 ThreadSafeTrieRawHashMapBase::PointerBase
231 ThreadSafeTrieRawHashMapBase::find(ArrayRef<uint8_t> Hash) const {
232   assert(!Hash.empty() && "Uninitialized hash");
233 
234   ImplType *Impl = ImplPtr.load();
235   if (!Impl)
236     return PointerBase();
237 
238   TrieSubtrie *S = Impl->getRoot();
239   TrieHashIndexGenerator IndexGen{NumRootBits, NumSubtrieBits, Hash};
240   size_t Index = IndexGen.next();
241   while (Index != IndexGen.end()) {
242     // Try to set the content.
243     TrieNode *Existing = S->get(Index);
244     if (!Existing)
245       return PointerBase(S, Index, *IndexGen.StartBit);
246 
247     // Check for an exact match.
248     if (auto *ExistingContent = dyn_cast<TrieContent>(Existing))
249       return ExistingContent->getHash() == Hash
250                  ? PointerBase(ExistingContent->getValuePointer())
251                  : PointerBase(S, Index, *IndexGen.StartBit);
252 
253     Index = IndexGen.next();
254     S = cast<TrieSubtrie>(Existing);
255   }
256   llvm_unreachable("failed to locate the node after consuming all hash bytes");
257 }
258 
259 ThreadSafeTrieRawHashMapBase::PointerBase ThreadSafeTrieRawHashMapBase::insert(
260     PointerBase Hint, ArrayRef<uint8_t> Hash,
261     function_ref<const uint8_t *(void *Mem, ArrayRef<uint8_t> Hash)>
262         Constructor) {
263   assert(!Hash.empty() && "Uninitialized hash");
264 
265   ImplType &Impl = getOrCreateImpl();
266   TrieSubtrie *S = Impl.getRoot();
267   TrieHashIndexGenerator IndexGen{NumRootBits, NumSubtrieBits, Hash};
268   size_t Index;
269   if (Hint.isHint()) {
270     S = static_cast<TrieSubtrie *>(Hint.P);
271     Index = IndexGen.hint(Hint.I, Hint.B);
272   } else {
273     Index = IndexGen.next();
274   }
275 
276   while (Index != IndexGen.end()) {
277     // Load the node from the slot, allocating and calling the constructor if
278     // the slot is empty.
279     bool Generated = false;
280     TrieNode &Existing = S->get(Index).loadOrGenerate([&]() {
281       Generated = true;
282 
283       // Construct the value itself at the tail.
284       uint8_t *Memory = reinterpret_cast<uint8_t *>(
285           Impl.ContentAlloc.Allocate(ContentAllocSize, ContentAllocAlign));
286       const uint8_t *HashStorage = Constructor(Memory + ContentOffset, Hash);
287 
288       // Construct the TrieContent header, passing in the offset to the hash.
289       TrieContent *Content = ::new (Memory)
290           TrieContent(ContentOffset, Hash.size(), HashStorage - Memory);
291       assert(Hash == Content->getHash() && "Hash not properly initialized");
292       return Content;
293     });
294     // If we just generated it, return it!
295     if (Generated)
296       return PointerBase(cast<TrieContent>(Existing).getValuePointer());
297 
298     if (auto *ST = dyn_cast<TrieSubtrie>(&Existing)) {
299       S = ST;
300       Index = IndexGen.next();
301       continue;
302     }
303 
304     // Return the existing content if it's an exact match!
305     auto &ExistingContent = cast<TrieContent>(Existing);
306     if (ExistingContent.getHash() == Hash)
307       return PointerBase(ExistingContent.getValuePointer());
308 
309     // Sink the existing content as long as the indexes match.
310     size_t NextIndex = IndexGen.next();
311     while (NextIndex != IndexGen.end()) {
312       size_t NewIndexForExistingContent =
313           IndexGen.getCollidingBits(ExistingContent.getHash());
314       S = S->sink(Index, ExistingContent, IndexGen.getNumBits(),
315                   NewIndexForExistingContent,
316                   [&Impl](std::unique_ptr<TrieSubtrie> S) {
317                     return Impl.save(std::move(S));
318                   });
319       Index = NextIndex;
320 
321       // Found the difference.
322       if (NextIndex != NewIndexForExistingContent)
323         break;
324 
325       NextIndex = IndexGen.next();
326     }
327   }
328   llvm_unreachable("failed to insert the node after consuming all hash bytes");
329 }
330 
331 ThreadSafeTrieRawHashMapBase::ThreadSafeTrieRawHashMapBase(
332     size_t ContentAllocSize, size_t ContentAllocAlign, size_t ContentOffset,
333     std::optional<size_t> NumRootBits, std::optional<size_t> NumSubtrieBits)
334     : ContentAllocSize(ContentAllocSize), ContentAllocAlign(ContentAllocAlign),
335       ContentOffset(ContentOffset),
336       NumRootBits(NumRootBits ? *NumRootBits : DefaultNumRootBits),
337       NumSubtrieBits(NumSubtrieBits ? *NumSubtrieBits : DefaultNumSubtrieBits),
338       ImplPtr(nullptr) {
339   // Assertion checks for reasonable configuration. The settings below are not
340   // hard limits on most platforms, but a reasonable configuration should fall
341   // within those limits.
342   assert((!NumRootBits || *NumRootBits < 20) &&
343          "Root should have fewer than ~1M slots");
344   assert((!NumSubtrieBits || *NumSubtrieBits < 10) &&
345          "Subtries should have fewer than ~1K slots");
346 }
347 
348 ThreadSafeTrieRawHashMapBase::ThreadSafeTrieRawHashMapBase(
349     ThreadSafeTrieRawHashMapBase &&RHS)
350     : ContentAllocSize(RHS.ContentAllocSize),
351       ContentAllocAlign(RHS.ContentAllocAlign),
352       ContentOffset(RHS.ContentOffset), NumRootBits(RHS.NumRootBits),
353       NumSubtrieBits(RHS.NumSubtrieBits) {
354   // Steal the root from RHS.
355   ImplPtr = RHS.ImplPtr.exchange(nullptr);
356 }
357 
358 ThreadSafeTrieRawHashMapBase::~ThreadSafeTrieRawHashMapBase() {
359   assert(!ImplPtr.load() && "Expected subclass to call destroyImpl()");
360 }
361 
362 void ThreadSafeTrieRawHashMapBase::destroyImpl(
363     function_ref<void(void *)> Destructor) {
364   std::unique_ptr<ImplType> Impl(ImplPtr.exchange(nullptr));
365   if (!Impl)
366     return;
367 
368   // Destroy content nodes throughout trie. Avoid destroying any subtries since
369   // we need TrieNode::classof() to find the content nodes.
370   //
371   // FIXME: Once we have bitsets (see FIXME in TrieSubtrie class), use them
372   // facilitate sparse iteration here.
373   if (Destructor)
374     for (TrieSubtrie *Trie = Impl->getRoot(); Trie; Trie = Trie->Next.load())
375       for (unsigned I = 0; I < Trie->size(); ++I)
376         if (auto *Content = dyn_cast_or_null<TrieContent>(Trie->load(I)))
377           Destructor(Content->getValuePointer());
378 
379   // Destroy the subtries. Incidentally, this destroys them in the reverse order
380   // of saving.
381   TrieSubtrie *Trie = Impl->getRoot()->Next;
382   while (Trie) {
383     TrieSubtrie *Next = Trie->Next.exchange(nullptr);
384     delete Trie;
385     Trie = Next;
386   }
387 }
388 
389 ThreadSafeTrieRawHashMapBase::PointerBase
390 ThreadSafeTrieRawHashMapBase::getRoot() const {
391   ImplType *Impl = ImplPtr.load();
392   if (!Impl)
393     return PointerBase();
394   return PointerBase(Impl->getRoot());
395 }
396 
397 unsigned ThreadSafeTrieRawHashMapBase::getStartBit(
398     ThreadSafeTrieRawHashMapBase::PointerBase P) const {
399   assert(!P.isHint() && "Not a valid trie");
400   if (!P.P)
401     return 0;
402   if (auto *S = dyn_cast<TrieSubtrie>((TrieNode *)P.P))
403     return S->StartBit;
404   return 0;
405 }
406 
407 unsigned ThreadSafeTrieRawHashMapBase::getNumBits(
408     ThreadSafeTrieRawHashMapBase::PointerBase P) const {
409   assert(!P.isHint() && "Not a valid trie");
410   if (!P.P)
411     return 0;
412   if (auto *S = dyn_cast<TrieSubtrie>((TrieNode *)P.P))
413     return S->NumBits;
414   return 0;
415 }
416 
417 unsigned ThreadSafeTrieRawHashMapBase::getNumSlotUsed(
418     ThreadSafeTrieRawHashMapBase::PointerBase P) const {
419   assert(!P.isHint() && "Not a valid trie");
420   if (!P.P)
421     return 0;
422   auto *S = dyn_cast<TrieSubtrie>((TrieNode *)P.P);
423   if (!S)
424     return 0;
425   unsigned Num = 0;
426   for (unsigned I = 0, E = S->size(); I < E; ++I)
427     if (S->load(I))
428       ++Num;
429   return Num;
430 }
431 
432 std::string ThreadSafeTrieRawHashMapBase::getTriePrefixAsString(
433     ThreadSafeTrieRawHashMapBase::PointerBase P) const {
434   assert(!P.isHint() && "Not a valid trie");
435   if (!P.P)
436     return "";
437 
438   auto *S = dyn_cast<TrieSubtrie>((TrieNode *)P.P);
439   if (!S || !S->IsSubtrie)
440     return "";
441 
442   // Find a TrieContent node which has hash stored. Depth search following the
443   // first used slot until a TrieContent node is found.
444   TrieSubtrie *Current = S;
445   TrieContent *Node = nullptr;
446   while (Current) {
447     TrieSubtrie *Next = nullptr;
448     // Find first used slot in the trie.
449     for (unsigned I = 0, E = Current->size(); I < E; ++I) {
450       auto *S = Current->load(I);
451       if (!S)
452         continue;
453 
454       if (auto *Content = dyn_cast<TrieContent>(S))
455         Node = Content;
456       else if (auto *Sub = dyn_cast<TrieSubtrie>(S))
457         Next = Sub;
458       break;
459     }
460 
461     // Found the node.
462     if (Node)
463       break;
464 
465     // Continue to the next level if the node is not found.
466     Current = Next;
467   }
468 
469   assert(Node && "malformed trie, cannot find TrieContent on leaf node");
470   // The prefix for the current trie is the first `StartBit` of the content
471   // stored underneath this subtrie.
472   std::string Str;
473   raw_string_ostream SS(Str);
474 
475   unsigned StartFullBytes = (S->StartBit + 1) / 8 - 1;
476   SS << toHex(toStringRef(Node->getHash()).take_front(StartFullBytes),
477               /*LowerCase=*/true);
478 
479   // For the part of the prefix that doesn't fill a byte, print raw bit values.
480   std::string Bits;
481   for (unsigned I = StartFullBytes * 8, E = S->StartBit; I < E; ++I) {
482     unsigned Index = I / 8;
483     unsigned Offset = 7 - I % 8;
484     Bits.push_back('0' + ((Node->getHash()[Index] >> Offset) & 1));
485   }
486 
487   if (!Bits.empty())
488     SS << "[" << Bits << "]";
489 
490   return SS.str();
491 }
492 
493 unsigned ThreadSafeTrieRawHashMapBase::getNumTries() const {
494   ImplType *Impl = ImplPtr.load();
495   if (!Impl)
496     return 0;
497   unsigned Num = 0;
498   for (TrieSubtrie *Trie = Impl->getRoot(); Trie; Trie = Trie->Next.load())
499     ++Num;
500   return Num;
501 }
502 
503 ThreadSafeTrieRawHashMapBase::PointerBase
504 ThreadSafeTrieRawHashMapBase::getNextTrie(
505     ThreadSafeTrieRawHashMapBase::PointerBase P) const {
506   assert(!P.isHint() && "Not a valid trie");
507   if (!P.P)
508     return PointerBase();
509   auto *S = dyn_cast<TrieSubtrie>((TrieNode *)P.P);
510   if (!S)
511     return PointerBase();
512   if (auto *E = S->Next.load())
513     return PointerBase(E);
514   return PointerBase();
515 }
516