xref: /llvm-project/libc/fuzzing/__support/hashtable_fuzz.cpp (revision e59582b6f8f1be3e675866f6a5d661eb4c8ed448)
1 //===-- hashtable_fuzz.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// Fuzzing test for llvm-libc hashtable implementations.
10 ///
11 //===----------------------------------------------------------------------===//
12 #include "include/llvm-libc-types/ENTRY.h"
13 #include "src/__support/CPP/bit.h"
14 #include "src/__support/CPP/string_view.h"
15 #include "src/__support/HashTable/table.h"
16 #include "src/__support/macros/config.h"
17 
18 namespace LIBC_NAMESPACE_DECL {
19 
// Layout of a fuzzing payload.
//
// Header, INITIAL_HEADER_SIZE bytes in total:
//   - uint16_t: initial capacity for table A
//   - uint64_t: seed for table A
//   - uint16_t: initial capacity for table B
//   - uint64_t: seed for table B
//
// Body: a sequence of actions. The first byte of each action, taken modulo 5,
// selects its kind:
//   - 4:       CrossCheck; the action is that single byte.
//   - 3:       Find; the byte is followed by a null-terminated key.
//   - 0, 1, 2: Insert; the byte is followed by a null-terminated key.
static constexpr size_t INITIAL_HEADER_SIZE =
    sizeof(uint16_t) + sizeof(uint64_t)    // table A: capacity + seed
    + sizeof(uint16_t) + sizeof(uint64_t); // table B: capacity + seed

// Declared here, defined by the fuzzing engine (libFuzzer's LLVMFuzzerMutate
// API); LLVMFuzzerCustomMutator below calls it to apply generic mutations.
extern "C" size_t LLVMFuzzerMutate(uint8_t *data, size_t size,
                                   size_t max_size);
33 extern "C" size_t LLVMFuzzerCustomMutator(uint8_t *data, size_t size,
34                                           size_t max_size, unsigned int seed) {
35   size = LLVMFuzzerMutate(data, size, max_size);
36   // not enough to read the initial capacities and seeds
37   if (size < INITIAL_HEADER_SIZE)
38     return 0;
39 
40   // skip the initial capacities and seeds
41   size_t i = INITIAL_HEADER_SIZE;
42   while (i < size) {
43     // cross check
44     if (static_cast<uint8_t>(data[i]) % 5 == 4) {
45       // skip the cross check byte
46       ++i;
47       continue;
48     }
49 
50     // find or insert
51     // check if there is enough space for the action byte and the
52     // null-terminator
53     if (i + 2 >= max_size)
54       return i;
55     // skip the action byte
56     ++i;
57     // skip the null-terminated string
58     while (i < max_size && data[i] != 0)
59       ++i;
60     // in the case the string is not null-terminated, null-terminate it
61     if (i == max_size && data[i - 1] != 0) {
62       data[i - 1] = 0;
63       return max_size;
64     }
65 
66     // move to the next action
67     ++i;
68   }
69   // return the new size
70   return i;
71 }
72 
73 // a tagged union
74 struct Action {
75   enum class Tag { Find, Insert, CrossCheck } tag;
76   cpp::string_view key;
77 };
78 
79 static struct {
80   size_t remaining;
81   const char *buffer;
82 
83   template <typename T> T next() {
84     static_assert(cpp::is_integral<T>::value, "T must be an integral type");
85 
86     char data[sizeof(T)];
87 
88     for (size_t i = 0; i < sizeof(T); i++)
89       data[i] = buffer[i];
90     buffer += sizeof(T);
91     remaining -= sizeof(T);
92     return cpp::bit_cast<T>(data);
93   }
94 
95   cpp::string_view next_string() {
96     cpp::string_view result(buffer);
97     buffer = result.end() + 1;
98     remaining -= result.size() + 1;
99     return result;
100   }
101 
102   Action next_action() {
103     uint8_t byte = next<uint8_t>();
104     switch (byte % 5) {
105     case 4:
106       return {Action::Tag::CrossCheck, {}};
107     case 3:
108       return {Action::Tag::Find, next_string()};
109     default:
110       return {Action::Tag::Insert, next_string()};
111     }
112   }
113 } global_status;
114 
115 class HashTable {
116   internal::HashTable *table;
117 
118 public:
119   HashTable(uint64_t size, uint64_t seed)
120       : table(internal::HashTable::allocate(size, seed)) {}
121   HashTable(internal::HashTable *table) : table(table) {}
122   ~HashTable() { internal::HashTable::deallocate(table); }
123   HashTable(HashTable &&other) : table(other.table) { other.table = nullptr; }
124   bool is_valid() const { return table != nullptr; }
125   ENTRY *find(const char *key) { return table->find(key); }
126   ENTRY *insert(const ENTRY &entry) {
127     return internal::HashTable::insert(this->table, entry);
128   }
129   using iterator = internal::HashTable::iterator;
130   iterator begin() const { return table->begin(); }
131   iterator end() const { return table->end(); }
132 };
133 
134 HashTable next_hashtable() {
135   size_t size = global_status.next<uint16_t>();
136   uint64_t seed = global_status.next<uint64_t>();
137   return HashTable(size, seed);
138 }
139 
140 extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
141   global_status.buffer = reinterpret_cast<const char *>(data);
142   global_status.remaining = size;
143   if (global_status.remaining < INITIAL_HEADER_SIZE)
144     return 0;
145 
146   HashTable table_a = next_hashtable();
147   HashTable table_b = next_hashtable();
148   for (;;) {
149     if (global_status.remaining == 0)
150       break;
151     Action action = global_status.next_action();
152     switch (action.tag) {
153     case Action::Tag::Find: {
154       if (static_cast<bool>(table_a.find(action.key.data())) !=
155           static_cast<bool>(table_b.find(action.key.data())))
156         __builtin_trap();
157       break;
158     }
159     case Action::Tag::Insert: {
160       char *ptr = const_cast<char *>(action.key.data());
161       ENTRY *a = table_a.insert(ENTRY{ptr, ptr});
162       ENTRY *b = table_b.insert(ENTRY{ptr, ptr});
163       if (a->data != b->data)
164         __builtin_trap();
165       break;
166     }
167     case Action::Tag::CrossCheck: {
168       for (ENTRY a : table_a)
169         if (const ENTRY *b = table_b.find(a.key); a.data != b->data)
170           __builtin_trap();
171 
172       for (ENTRY b : table_b)
173         if (const ENTRY *a = table_a.find(b.key); a->data != b.data)
174           __builtin_trap();
175 
176       break;
177     }
178     }
179   }
180   return 0;
181 }
182 
183 } // namespace LIBC_NAMESPACE_DECL
184