1 //===- TrieRawHashMap.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/ADT/TrieRawHashMap.h" 10 #include "llvm/ADT/LazyAtomicPointer.h" 11 #include "llvm/ADT/StringExtras.h" 12 #include "llvm/ADT/TrieHashIndexGenerator.h" 13 #include "llvm/Support/Allocator.h" 14 #include "llvm/Support/Casting.h" 15 #include "llvm/Support/Debug.h" 16 #include "llvm/Support/ThreadSafeAllocator.h" 17 #include "llvm/Support/TrailingObjects.h" 18 #include "llvm/Support/raw_ostream.h" 19 #include <memory> 20 21 using namespace llvm; 22 23 namespace { 24 struct TrieNode { 25 const bool IsSubtrie = false; 26 27 TrieNode(bool IsSubtrie) : IsSubtrie(IsSubtrie) {} 28 29 static void *operator new(size_t Size) { return ::operator new(Size); } 30 void operator delete(void *Ptr) { ::operator delete(Ptr); } 31 }; 32 33 struct TrieContent final : public TrieNode { 34 const uint8_t ContentOffset; 35 const uint8_t HashSize; 36 const uint8_t HashOffset; 37 38 void *getValuePointer() const { 39 auto *Content = reinterpret_cast<const uint8_t *>(this) + ContentOffset; 40 return const_cast<uint8_t *>(Content); 41 } 42 43 ArrayRef<uint8_t> getHash() const { 44 auto *Begin = reinterpret_cast<const uint8_t *>(this) + HashOffset; 45 return ArrayRef(Begin, Begin + HashSize); 46 } 47 48 TrieContent(size_t ContentOffset, size_t HashSize, size_t HashOffset) 49 : TrieNode(/*IsSubtrie=*/false), ContentOffset(ContentOffset), 50 HashSize(HashSize), HashOffset(HashOffset) {} 51 52 static bool classof(const TrieNode *TN) { return !TN->IsSubtrie; } 53 }; 54 55 static_assert(sizeof(TrieContent) == 56 ThreadSafeTrieRawHashMapBase::TrieContentBaseSize, 57 "Check header assumption!"); 58 59 class TrieSubtrie final 60 : public TrieNode, 61 private TrailingObjects<TrieSubtrie, LazyAtomicPointer<TrieNode>> { 62 public: 63 using Slot = LazyAtomicPointer<TrieNode>; 64 65 Slot &get(size_t I) { return getTrailingObjects<Slot>()[I]; } 66 TrieNode *load(size_t I) { return get(I).load(); } 67 68 unsigned size() const { return Size; } 69 70 TrieSubtrie * 71 sink(size_t I, TrieContent &Content, size_t NumSubtrieBits, size_t NewI, 72 function_ref<TrieSubtrie *(std::unique_ptr<TrieSubtrie>)> Saver); 73 74 static std::unique_ptr<TrieSubtrie> create(size_t StartBit, size_t NumBits); 75 76 explicit TrieSubtrie(size_t StartBit, size_t NumBits); 77 78 static bool classof(const TrieNode *TN) { return TN->IsSubtrie; } 79 80 static constexpr size_t sizeToAlloc(unsigned NumBits) { 81 assert(NumBits < 20 && "Tries should have fewer than ~1M slots"); 82 unsigned Count = 1u << NumBits; 83 return totalSizeToAlloc<LazyAtomicPointer<TrieNode>>(Count); 84 } 85 86 private: 87 // FIXME: Use a bitset to speed up access: 88 // 89 // std::array<std::atomic<uint64_t>, NumSlots/64> IsSet; 90 // 91 // This will avoid needing to visit sparsely filled slots in 92 // \a ThreadSafeTrieRawHashMapBase::destroyImpl() when there's a non-trivial 93 // destructor. 94 // 95 // It would also greatly speed up iteration, if we add that some day, and 96 // allow get() to return one level sooner. 97 // 98 // This would be the algorithm for updating IsSet (after updating Slots): 99 // 100 // std::atomic<uint64_t> &Bits = IsSet[I.High]; 101 // const uint64_t NewBit = 1ULL << I.Low; 102 // uint64_t Old = 0; 103 // while (!Bits.compare_exchange_weak(Old, Old | NewBit)) 104 // ; 105 106 // For debugging. 107 unsigned StartBit = 0; 108 unsigned NumBits = 0; 109 unsigned Size = 0; 110 friend class llvm::ThreadSafeTrieRawHashMapBase; 111 friend class TrailingObjects; 112 113 public: 114 /// Linked list for ownership of tries. The pointer is owned by TrieSubtrie. 115 std::atomic<TrieSubtrie *> Next; 116 }; 117 } // end namespace 118 119 std::unique_ptr<TrieSubtrie> TrieSubtrie::create(size_t StartBit, 120 size_t NumBits) { 121 void *Memory = ::operator new(sizeToAlloc(NumBits)); 122 TrieSubtrie *S = ::new (Memory) TrieSubtrie(StartBit, NumBits); 123 return std::unique_ptr<TrieSubtrie>(S); 124 } 125 126 TrieSubtrie::TrieSubtrie(size_t StartBit, size_t NumBits) 127 : TrieNode(true), StartBit(StartBit), NumBits(NumBits), Size(1u << NumBits), 128 Next(nullptr) { 129 for (unsigned I = 0; I < Size; ++I) 130 new (&get(I)) Slot(nullptr); 131 132 static_assert( 133 std::is_trivially_destructible<LazyAtomicPointer<TrieNode>>::value, 134 "Expected no work in destructor for TrieNode"); 135 } 136 137 // Sink the nodes down sub-trie when the object being inserted collides with 138 // the index of existing object in the trie. In this case, a new sub-trie needs 139 // to be allocated to hold existing object. 140 TrieSubtrie *TrieSubtrie::sink( 141 size_t I, TrieContent &Content, size_t NumSubtrieBits, size_t NewI, 142 function_ref<TrieSubtrie *(std::unique_ptr<TrieSubtrie>)> Saver) { 143 // Create a new sub-trie that points to the existing object with the new 144 // index for the next level. 145 assert(NumSubtrieBits > 0); 146 std::unique_ptr<TrieSubtrie> S = create(StartBit + NumBits, NumSubtrieBits); 147 148 assert(NewI < Size); 149 S->get(NewI).store(&Content); 150 151 // Using compare_exchange to atomically add back the new sub-trie to the trie 152 // in the place of the exsiting object. 153 TrieNode *ExistingNode = &Content; 154 assert(I < Size); 155 if (get(I).compare_exchange_strong(ExistingNode, S.get())) 156 return Saver(std::move(S)); 157 158 // Another thread created a subtrie already. Return it and let "S" be 159 // destructed. 160 return cast<TrieSubtrie>(ExistingNode); 161 } 162 163 class ThreadSafeTrieRawHashMapBase::ImplType final 164 : private TrailingObjects<ThreadSafeTrieRawHashMapBase::ImplType, 165 TrieSubtrie> { 166 public: 167 static std::unique_ptr<ImplType> create(size_t StartBit, size_t NumBits) { 168 size_t Size = sizeof(ImplType) + TrieSubtrie::sizeToAlloc(NumBits); 169 void *Memory = ::operator new(Size); 170 ImplType *Impl = ::new (Memory) ImplType(StartBit, NumBits); 171 return std::unique_ptr<ImplType>(Impl); 172 } 173 174 // Save the Subtrie into the ownship list of the trie structure in a 175 // thread-safe way. The ownership transfer is done by compare_exchange the 176 // pointer value inside the unique_ptr. 177 TrieSubtrie *save(std::unique_ptr<TrieSubtrie> S) { 178 assert(!S->Next && "Expected S to a freshly-constructed leaf"); 179 180 TrieSubtrie *CurrentHead = nullptr; 181 // Add ownership of "S" to front of the list, so that Root -> S -> 182 // Root.Next. This works by repeatedly setting S->Next to a candidate value 183 // of Root.Next (initially nullptr), then setting Root.Next to S once the 184 // candidate matches reality. 185 while (!getRoot()->Next.compare_exchange_weak(CurrentHead, S.get())) 186 S->Next.exchange(CurrentHead); 187 188 // Ownership transferred to subtrie successfully. Release the unique_ptr. 189 return S.release(); 190 } 191 192 // Get the root which is the trailing object. 193 TrieSubtrie *getRoot() { return getTrailingObjects<TrieSubtrie>(); } 194 195 static void *operator new(size_t Size) { return ::operator new(Size); } 196 void operator delete(void *Ptr) { ::operator delete(Ptr); } 197 198 /// FIXME: This should take a function that allocates and constructs the 199 /// content lazily (taking the hash as a separate parameter), in case of 200 /// collision. 201 ThreadSafeAllocator<BumpPtrAllocator> ContentAlloc; 202 203 private: 204 friend class TrailingObjects; 205 206 ImplType(size_t StartBit, size_t NumBits) { 207 ::new (getRoot()) TrieSubtrie(StartBit, NumBits); 208 } 209 }; 210 211 ThreadSafeTrieRawHashMapBase::ImplType & 212 ThreadSafeTrieRawHashMapBase::getOrCreateImpl() { 213 if (ImplType *Impl = ImplPtr.load()) 214 return *Impl; 215 216 // Create a new ImplType and store it if another thread doesn't do so first. 217 // If another thread wins this one is destroyed locally. 218 std::unique_ptr<ImplType> Impl = ImplType::create(0, NumRootBits); 219 ImplType *ExistingImpl = nullptr; 220 221 // If the ownership transferred succesfully, release unique_ptr and return 222 // the pointer to the new ImplType. 223 if (ImplPtr.compare_exchange_strong(ExistingImpl, Impl.get())) 224 return *Impl.release(); 225 226 // Already created, return the existing ImplType. 227 return *ExistingImpl; 228 } 229 230 ThreadSafeTrieRawHashMapBase::PointerBase 231 ThreadSafeTrieRawHashMapBase::find(ArrayRef<uint8_t> Hash) const { 232 assert(!Hash.empty() && "Uninitialized hash"); 233 234 ImplType *Impl = ImplPtr.load(); 235 if (!Impl) 236 return PointerBase(); 237 238 TrieSubtrie *S = Impl->getRoot(); 239 TrieHashIndexGenerator IndexGen{NumRootBits, NumSubtrieBits, Hash}; 240 size_t Index = IndexGen.next(); 241 while (Index != IndexGen.end()) { 242 // Try to set the content. 243 TrieNode *Existing = S->get(Index); 244 if (!Existing) 245 return PointerBase(S, Index, *IndexGen.StartBit); 246 247 // Check for an exact match. 248 if (auto *ExistingContent = dyn_cast<TrieContent>(Existing)) 249 return ExistingContent->getHash() == Hash 250 ? PointerBase(ExistingContent->getValuePointer()) 251 : PointerBase(S, Index, *IndexGen.StartBit); 252 253 Index = IndexGen.next(); 254 S = cast<TrieSubtrie>(Existing); 255 } 256 llvm_unreachable("failed to locate the node after consuming all hash bytes"); 257 } 258 259 ThreadSafeTrieRawHashMapBase::PointerBase ThreadSafeTrieRawHashMapBase::insert( 260 PointerBase Hint, ArrayRef<uint8_t> Hash, 261 function_ref<const uint8_t *(void *Mem, ArrayRef<uint8_t> Hash)> 262 Constructor) { 263 assert(!Hash.empty() && "Uninitialized hash"); 264 265 ImplType &Impl = getOrCreateImpl(); 266 TrieSubtrie *S = Impl.getRoot(); 267 TrieHashIndexGenerator IndexGen{NumRootBits, NumSubtrieBits, Hash}; 268 size_t Index; 269 if (Hint.isHint()) { 270 S = static_cast<TrieSubtrie *>(Hint.P); 271 Index = IndexGen.hint(Hint.I, Hint.B); 272 } else { 273 Index = IndexGen.next(); 274 } 275 276 while (Index != IndexGen.end()) { 277 // Load the node from the slot, allocating and calling the constructor if 278 // the slot is empty. 279 bool Generated = false; 280 TrieNode &Existing = S->get(Index).loadOrGenerate([&]() { 281 Generated = true; 282 283 // Construct the value itself at the tail. 284 uint8_t *Memory = reinterpret_cast<uint8_t *>( 285 Impl.ContentAlloc.Allocate(ContentAllocSize, ContentAllocAlign)); 286 const uint8_t *HashStorage = Constructor(Memory + ContentOffset, Hash); 287 288 // Construct the TrieContent header, passing in the offset to the hash. 289 TrieContent *Content = ::new (Memory) 290 TrieContent(ContentOffset, Hash.size(), HashStorage - Memory); 291 assert(Hash == Content->getHash() && "Hash not properly initialized"); 292 return Content; 293 }); 294 // If we just generated it, return it! 295 if (Generated) 296 return PointerBase(cast<TrieContent>(Existing).getValuePointer()); 297 298 if (auto *ST = dyn_cast<TrieSubtrie>(&Existing)) { 299 S = ST; 300 Index = IndexGen.next(); 301 continue; 302 } 303 304 // Return the existing content if it's an exact match! 305 auto &ExistingContent = cast<TrieContent>(Existing); 306 if (ExistingContent.getHash() == Hash) 307 return PointerBase(ExistingContent.getValuePointer()); 308 309 // Sink the existing content as long as the indexes match. 310 size_t NextIndex = IndexGen.next(); 311 while (NextIndex != IndexGen.end()) { 312 size_t NewIndexForExistingContent = 313 IndexGen.getCollidingBits(ExistingContent.getHash()); 314 S = S->sink(Index, ExistingContent, IndexGen.getNumBits(), 315 NewIndexForExistingContent, 316 [&Impl](std::unique_ptr<TrieSubtrie> S) { 317 return Impl.save(std::move(S)); 318 }); 319 Index = NextIndex; 320 321 // Found the difference. 322 if (NextIndex != NewIndexForExistingContent) 323 break; 324 325 NextIndex = IndexGen.next(); 326 } 327 } 328 llvm_unreachable("failed to insert the node after consuming all hash bytes"); 329 } 330 331 ThreadSafeTrieRawHashMapBase::ThreadSafeTrieRawHashMapBase( 332 size_t ContentAllocSize, size_t ContentAllocAlign, size_t ContentOffset, 333 std::optional<size_t> NumRootBits, std::optional<size_t> NumSubtrieBits) 334 : ContentAllocSize(ContentAllocSize), ContentAllocAlign(ContentAllocAlign), 335 ContentOffset(ContentOffset), 336 NumRootBits(NumRootBits ? *NumRootBits : DefaultNumRootBits), 337 NumSubtrieBits(NumSubtrieBits ? *NumSubtrieBits : DefaultNumSubtrieBits), 338 ImplPtr(nullptr) { 339 // Assertion checks for reasonable configuration. The settings below are not 340 // hard limits on most platforms, but a reasonable configuration should fall 341 // within those limits. 342 assert((!NumRootBits || *NumRootBits < 20) && 343 "Root should have fewer than ~1M slots"); 344 assert((!NumSubtrieBits || *NumSubtrieBits < 10) && 345 "Subtries should have fewer than ~1K slots"); 346 } 347 348 ThreadSafeTrieRawHashMapBase::ThreadSafeTrieRawHashMapBase( 349 ThreadSafeTrieRawHashMapBase &&RHS) 350 : ContentAllocSize(RHS.ContentAllocSize), 351 ContentAllocAlign(RHS.ContentAllocAlign), 352 ContentOffset(RHS.ContentOffset), NumRootBits(RHS.NumRootBits), 353 NumSubtrieBits(RHS.NumSubtrieBits) { 354 // Steal the root from RHS. 355 ImplPtr = RHS.ImplPtr.exchange(nullptr); 356 } 357 358 ThreadSafeTrieRawHashMapBase::~ThreadSafeTrieRawHashMapBase() { 359 assert(!ImplPtr.load() && "Expected subclass to call destroyImpl()"); 360 } 361 362 void ThreadSafeTrieRawHashMapBase::destroyImpl( 363 function_ref<void(void *)> Destructor) { 364 std::unique_ptr<ImplType> Impl(ImplPtr.exchange(nullptr)); 365 if (!Impl) 366 return; 367 368 // Destroy content nodes throughout trie. Avoid destroying any subtries since 369 // we need TrieNode::classof() to find the content nodes. 370 // 371 // FIXME: Once we have bitsets (see FIXME in TrieSubtrie class), use them 372 // facilitate sparse iteration here. 373 if (Destructor) 374 for (TrieSubtrie *Trie = Impl->getRoot(); Trie; Trie = Trie->Next.load()) 375 for (unsigned I = 0; I < Trie->size(); ++I) 376 if (auto *Content = dyn_cast_or_null<TrieContent>(Trie->load(I))) 377 Destructor(Content->getValuePointer()); 378 379 // Destroy the subtries. Incidentally, this destroys them in the reverse order 380 // of saving. 381 TrieSubtrie *Trie = Impl->getRoot()->Next; 382 while (Trie) { 383 TrieSubtrie *Next = Trie->Next.exchange(nullptr); 384 delete Trie; 385 Trie = Next; 386 } 387 } 388 389 ThreadSafeTrieRawHashMapBase::PointerBase 390 ThreadSafeTrieRawHashMapBase::getRoot() const { 391 ImplType *Impl = ImplPtr.load(); 392 if (!Impl) 393 return PointerBase(); 394 return PointerBase(Impl->getRoot()); 395 } 396 397 unsigned ThreadSafeTrieRawHashMapBase::getStartBit( 398 ThreadSafeTrieRawHashMapBase::PointerBase P) const { 399 assert(!P.isHint() && "Not a valid trie"); 400 if (!P.P) 401 return 0; 402 if (auto *S = dyn_cast<TrieSubtrie>((TrieNode *)P.P)) 403 return S->StartBit; 404 return 0; 405 } 406 407 unsigned ThreadSafeTrieRawHashMapBase::getNumBits( 408 ThreadSafeTrieRawHashMapBase::PointerBase P) const { 409 assert(!P.isHint() && "Not a valid trie"); 410 if (!P.P) 411 return 0; 412 if (auto *S = dyn_cast<TrieSubtrie>((TrieNode *)P.P)) 413 return S->NumBits; 414 return 0; 415 } 416 417 unsigned ThreadSafeTrieRawHashMapBase::getNumSlotUsed( 418 ThreadSafeTrieRawHashMapBase::PointerBase P) const { 419 assert(!P.isHint() && "Not a valid trie"); 420 if (!P.P) 421 return 0; 422 auto *S = dyn_cast<TrieSubtrie>((TrieNode *)P.P); 423 if (!S) 424 return 0; 425 unsigned Num = 0; 426 for (unsigned I = 0, E = S->size(); I < E; ++I) 427 if (S->load(I)) 428 ++Num; 429 return Num; 430 } 431 432 std::string ThreadSafeTrieRawHashMapBase::getTriePrefixAsString( 433 ThreadSafeTrieRawHashMapBase::PointerBase P) const { 434 assert(!P.isHint() && "Not a valid trie"); 435 if (!P.P) 436 return ""; 437 438 auto *S = dyn_cast<TrieSubtrie>((TrieNode *)P.P); 439 if (!S || !S->IsSubtrie) 440 return ""; 441 442 // Find a TrieContent node which has hash stored. Depth search following the 443 // first used slot until a TrieContent node is found. 444 TrieSubtrie *Current = S; 445 TrieContent *Node = nullptr; 446 while (Current) { 447 TrieSubtrie *Next = nullptr; 448 // Find first used slot in the trie. 449 for (unsigned I = 0, E = Current->size(); I < E; ++I) { 450 auto *S = Current->load(I); 451 if (!S) 452 continue; 453 454 if (auto *Content = dyn_cast<TrieContent>(S)) 455 Node = Content; 456 else if (auto *Sub = dyn_cast<TrieSubtrie>(S)) 457 Next = Sub; 458 break; 459 } 460 461 // Found the node. 462 if (Node) 463 break; 464 465 // Continue to the next level if the node is not found. 466 Current = Next; 467 } 468 469 assert(Node && "malformed trie, cannot find TrieContent on leaf node"); 470 // The prefix for the current trie is the first `StartBit` of the content 471 // stored underneath this subtrie. 472 std::string Str; 473 raw_string_ostream SS(Str); 474 475 unsigned StartFullBytes = (S->StartBit + 1) / 8 - 1; 476 SS << toHex(toStringRef(Node->getHash()).take_front(StartFullBytes), 477 /*LowerCase=*/true); 478 479 // For the part of the prefix that doesn't fill a byte, print raw bit values. 480 std::string Bits; 481 for (unsigned I = StartFullBytes * 8, E = S->StartBit; I < E; ++I) { 482 unsigned Index = I / 8; 483 unsigned Offset = 7 - I % 8; 484 Bits.push_back('0' + ((Node->getHash()[Index] >> Offset) & 1)); 485 } 486 487 if (!Bits.empty()) 488 SS << "[" << Bits << "]"; 489 490 return SS.str(); 491 } 492 493 unsigned ThreadSafeTrieRawHashMapBase::getNumTries() const { 494 ImplType *Impl = ImplPtr.load(); 495 if (!Impl) 496 return 0; 497 unsigned Num = 0; 498 for (TrieSubtrie *Trie = Impl->getRoot(); Trie; Trie = Trie->Next.load()) 499 ++Num; 500 return Num; 501 } 502 503 ThreadSafeTrieRawHashMapBase::PointerBase 504 ThreadSafeTrieRawHashMapBase::getNextTrie( 505 ThreadSafeTrieRawHashMapBase::PointerBase P) const { 506 assert(!P.isHint() && "Not a valid trie"); 507 if (!P.P) 508 return PointerBase(); 509 auto *S = dyn_cast<TrieSubtrie>((TrieNode *)P.P); 510 if (!S) 511 return PointerBase(); 512 if (auto *E = S->Next.load()) 513 return PointerBase(E); 514 return PointerBase(); 515 } 516