xref: /netbsd-src/external/mpl/bind/dist/lib/dns/compress.c (revision bcda20f65a8566e103791ec395f7f499ef322704)
1 /*	$NetBSD: compress.c,v 1.10 2025/01/26 16:25:22 christos Exp $	*/
2 
3 /*
4  * Copyright (C) Internet Systems Consortium, Inc. ("ISC")
5  *
6  * SPDX-License-Identifier: MPL-2.0
7  *
8  * This Source Code Form is subject to the terms of the Mozilla Public
9  * License, v. 2.0. If a copy of the MPL was not distributed with this
10  * file, you can obtain one at https://mozilla.org/MPL/2.0/.
11  *
12  * See the COPYRIGHT file distributed with this work for additional
13  * information regarding copyright ownership.
14  */
15 
16 #include <stdbool.h>
17 #include <stdint.h>
18 #include <string.h>
19 
20 #include <isc/ascii.h>
21 #include <isc/buffer.h>
22 #include <isc/hash.h>
23 #include <isc/mem.h>
24 #include <isc/util.h>
25 
26 #include <dns/compress.h>
27 #include <dns/name.h>
28 
/*
 * Initial value of the djb2 string hash that hash_label() accumulates
 * over the labels of a name's suffix.
 */
enum { HASH_INIT_DJB2 = 5381 };

/* Magic number identifying a live dns_compress_t, and its validity check. */
#define CCTX_MAGIC    ISC_MAGIC('C', 'C', 'T', 'X')
#define CCTX_VALID(x) ISC_MAGIC_VALID(x, CCTX_MAGIC)
33 
34 void
35 dns_compress_init(dns_compress_t *cctx, isc_mem_t *mctx,
36 		  dns_compress_flags_t flags) {
37 	dns_compress_slot_t *set = NULL;
38 	uint16_t mask;
39 
40 	REQUIRE(cctx != NULL);
41 	REQUIRE(mctx != NULL);
42 
43 	if ((flags & DNS_COMPRESS_LARGE) != 0) {
44 		size_t count = (1 << DNS_COMPRESS_LARGEBITS);
45 		mask = count - 1;
46 		set = isc_mem_callocate(mctx, count, sizeof(*set));
47 	} else {
48 		mask = ARRAY_SIZE(cctx->smallset) - 1;
49 		set = cctx->smallset;
50 	}
51 
52 	/*
53 	 * The lifetime of this object is limited to the stack frame of the
54 	 * caller, so we don't need to attach to the memory context.
55 	 */
56 	*cctx = (dns_compress_t){
57 		.magic = CCTX_MAGIC,
58 		.flags = flags | DNS_COMPRESS_PERMITTED,
59 		.mctx = mctx,
60 		.mask = mask,
61 		.set = set,
62 	};
63 }
64 
65 void
66 dns_compress_invalidate(dns_compress_t *cctx) {
67 	REQUIRE(CCTX_VALID(cctx));
68 	if (cctx->set != cctx->smallset) {
69 		isc_mem_free(cctx->mctx, cctx->set);
70 	}
71 	*cctx = (dns_compress_t){ 0 };
72 }
73 
74 void
75 dns_compress_setpermitted(dns_compress_t *cctx, bool permitted) {
76 	REQUIRE(CCTX_VALID(cctx));
77 	if (permitted) {
78 		cctx->flags |= DNS_COMPRESS_PERMITTED;
79 	} else {
80 		cctx->flags &= ~DNS_COMPRESS_PERMITTED;
81 	}
82 }
83 
84 bool
85 dns_compress_getpermitted(dns_compress_t *cctx) {
86 	REQUIRE(CCTX_VALID(cctx));
87 	return (cctx->flags & DNS_COMPRESS_PERMITTED) != 0;
88 }
89 
/*
 * Mix one wire-format label (its length byte followed by the label bytes)
 * into a running hash, djb2 style: acc = acc * 33 + byte for every byte,
 * including the length byte.
 *
 * The hash has to cover a name's whole suffix, but it is built up one
 * label at a time, so each call extends the hash passed in as `init`.
 * djb2 is used instead of isc_hash32() because it is much faster. The
 * impact of collision attacks is limited by restricting the size and
 * occupancy of the hash set.
 *
 * The accumulator is 32 bits wide so that the mixing in the upper bits is
 * kept. The result is reduced to 16 bits with isc_hash_bits32(). When
 * `sensitive` is false, bytes are lower-cased before mixing.
 */
static uint16_t
hash_label(uint16_t init, uint8_t *ptr, bool sensitive) {
	unsigned int nbytes = ptr[0] + 1;
	uint32_t acc = init;

	if (!sensitive) {
		/* the autovectorize-friendly tolower() */
		for (unsigned int i = 0; i < nbytes; i++) {
			acc = acc * 33 + isc__ascii_tolower1(ptr[i]);
		}
	} else {
		for (unsigned int i = 0; i < nbytes; i++) {
			acc = acc * 33 + ptr[i];
		}
	}

	return isc_hash_bits32(acc, 16);
}
116 
/*
 * Compare `len` bytes of two wire-format name fragments. The comparison
 * is exact when `sensitive` is true, and ASCII case-insensitive otherwise.
 */
static bool
match_wirename(uint8_t *a, uint8_t *b, unsigned int len, bool sensitive) {
	if (!sensitive) {
		/*
		 * Label length bytes are below 'A', so case folding
		 * leaves them unchanged.
		 */
		return isc_ascii_lowerequal(a, b, len);
	}

	return memcmp(a, b, len) == 0;
}
126 
127 /*
128  * We have found a hash set entry whose hash value matches the current
129  * suffix of our name, which is passed to this function via `sptr` and
130  * `slen`. We need to verify that the suffix in the message (referred to
131  * by `new_coff`) actually matches, in case of hash collisions.
132  *
133  * We know that the previous suffix of this name (after the first label)
134  * occurs in the message at `old_coff`, and all the compression offsets in
135  * the hash set and in the message refer to the first occurrence of a
136  * particular name or suffix.
137  *
138  * First, we need to match the label that was just added to our suffix,
139  * and second, verify that it is followed by the previous suffix.
140  *
141  * There are a few ways to match the previous suffix:
142  *
143  * When the first occurrence of this suffix is also the first occurrence
144  * of the previous suffix, `old_coff` points just after the new label.
145  *
146  * Otherwise, if this suffix occurs in a compressed name, it will be
147  * followed by a compression pointer that refers to the previous suffix,
148  * which must be equal to `old_coff`.
149  *
150  * The final possibility is that this suffix occurs in an uncompressed
151  * name, so we have to compare the rest of the suffix in full.
152  *
153  * A special case is when this suffix is a TLD. That can be handled by
154  * the case for uncompressed names, but it is common enough that it is
155  * worth taking a short cut. (In the TLD case, the `old_coff` will be
156  * zero, and the quick checks for the previous suffix will fail.)
157  */
158 static bool
159 match_suffix(isc_buffer_t *buffer, unsigned int new_coff, uint8_t *sptr,
160 	     unsigned int slen, unsigned int old_coff, bool sensitive) {
161 	uint8_t pptr[] = { 0xC0 | (old_coff >> 8), old_coff & 0xff };
162 	uint8_t *bptr = isc_buffer_base(buffer);
163 	unsigned int blen = isc_buffer_usedlength(buffer);
164 	unsigned int llen = sptr[0] + 1;
165 
166 	INSIST(llen <= 64 && llen < slen);
167 
168 	if (blen < new_coff + llen) {
169 		return false;
170 	}
171 
172 	blen -= new_coff;
173 	bptr += new_coff;
174 
175 	/* does the first label of the suffix appear here? */
176 	if (!match_wirename(bptr, sptr, llen, sensitive)) {
177 		return false;
178 	}
179 
180 	/* is this label followed by the previously matched suffix? */
181 	if (old_coff == new_coff + llen) {
182 		return true;
183 	}
184 
185 	blen -= llen;
186 	bptr += llen;
187 	slen -= llen;
188 	sptr += llen;
189 
190 	/* are both labels followed by the root label? */
191 	if (blen >= 1 && slen == 1 && bptr[0] == 0 && sptr[0] == 0) {
192 		return true;
193 	}
194 
195 	/* is this label followed by a pointer to the previous match? */
196 	if (blen >= 2 && bptr[0] == pptr[0] && bptr[1] == pptr[1]) {
197 		return true;
198 	}
199 
200 	/* is this label followed by a copy of the rest of the suffix? */
201 	return blen >= slen && match_wirename(bptr, sptr, slen, sensitive);
202 }
203 
204 /*
205  * Robin Hood hashing aims to minimize probe distance when inserting a
206  * new element by ensuring that the new element does not have a worse
207  * probe distance than any other element in its probe sequence. During
208  * insertion, if an existing element is encountered with a shorter
209  * probe distance, it is swapped with the new element, and insertion
210  * continues with the displaced element.
211  */
212 static unsigned int
213 probe_distance(dns_compress_t *cctx, unsigned int slot) {
214 	return (slot - cctx->set[slot].hash) & cctx->mask;
215 }
216 
217 static unsigned int
218 slot_index(dns_compress_t *cctx, unsigned int hash, unsigned int probe) {
219 	return (hash + probe) & cctx->mask;
220 }
221 
222 static bool
223 insert_label(dns_compress_t *cctx, isc_buffer_t *buffer, const dns_name_t *name,
224 	     unsigned int label, uint16_t hash, unsigned int probe) {
225 	/*
226 	 * hash set entries must have valid compression offsets
227 	 * and the hash set must not get too full (75% load)
228 	 */
229 	unsigned int prefix_len = name->offsets[label];
230 	unsigned int coff = isc_buffer_usedlength(buffer) + prefix_len;
231 	if (coff >= 0x4000 || cctx->count > cctx->mask * 3 / 4) {
232 		return false;
233 	}
234 	for (;;) {
235 		unsigned int slot = slot_index(cctx, hash, probe);
236 		/* we can stop when we find an empty slot */
237 		if (cctx->set[slot].coff == 0) {
238 			cctx->set[slot].hash = hash;
239 			cctx->set[slot].coff = coff;
240 			cctx->count++;
241 			return true;
242 		}
243 		/* he steals from the rich and gives to the poor */
244 		if (probe > probe_distance(cctx, slot)) {
245 			probe = probe_distance(cctx, slot);
246 			ISC_SWAP(cctx->set[slot].hash, hash);
247 			ISC_SWAP(cctx->set[slot].coff, coff);
248 		}
249 		probe++;
250 	}
251 }
252 
253 /*
254  * Add the unmatched prefix of the name to the hash set.
255  */
256 static void
257 insert(dns_compress_t *cctx, isc_buffer_t *buffer, const dns_name_t *name,
258        unsigned int label, uint16_t hash, unsigned int probe) {
259 	bool sensitive = (cctx->flags & DNS_COMPRESS_CASE) != 0;
260 	/*
261 	 * this insertion loop continues from the search loop inside
262 	 * dns_compress_name() below, iterating over the remaining labels
263 	 * of the name and accumulating the hash in the same manner
264 	 */
265 	while (insert_label(cctx, buffer, name, label, hash, probe) &&
266 	       label-- > 0)
267 	{
268 		unsigned int prefix_len = name->offsets[label];
269 		uint8_t *suffix_ptr = name->ndata + prefix_len;
270 		hash = hash_label(hash, suffix_ptr, sensitive);
271 		probe = 0;
272 	}
273 }
274 
275 void
276 dns_compress_name(dns_compress_t *cctx, isc_buffer_t *buffer,
277 		  const dns_name_t *name, unsigned int *return_prefix,
278 		  unsigned int *return_coff) {
279 	REQUIRE(CCTX_VALID(cctx));
280 	REQUIRE(ISC_BUFFER_VALID(buffer));
281 	REQUIRE(dns_name_isabsolute(name));
282 	REQUIRE(name->labels > 0);
283 	REQUIRE(name->offsets != NULL);
284 	REQUIRE(return_prefix != NULL);
285 	REQUIRE(return_coff != NULL);
286 	REQUIRE(*return_coff == 0);
287 
288 	if ((cctx->flags & DNS_COMPRESS_DISABLED) != 0) {
289 		return;
290 	}
291 
292 	bool sensitive = (cctx->flags & DNS_COMPRESS_CASE) != 0;
293 
294 	uint16_t hash = HASH_INIT_DJB2;
295 	unsigned int label = name->labels - 1; /* skip the root label */
296 
297 	/*
298 	 * find out how much of the name's suffix is in the hash set,
299 	 * stepping backwards from the end one label at a time
300 	 */
301 	while (label-- > 0) {
302 		unsigned int prefix_len = name->offsets[label];
303 		unsigned int suffix_len = name->length - prefix_len;
304 		uint8_t *suffix_ptr = name->ndata + prefix_len;
305 		hash = hash_label(hash, suffix_ptr, sensitive);
306 
307 		for (unsigned int probe = 0; true; probe++) {
308 			unsigned int slot = slot_index(cctx, hash, probe);
309 			unsigned int coff = cctx->set[slot].coff;
310 
311 			/*
312 			 * if we would have inserted this entry here (as in
313 			 * insert_label() above), our suffix cannot be in the
314 			 * hash set, so stop searching and switch to inserting
315 			 * the rest of the name (its prefix) into the set
316 			 */
317 			if (coff == 0 || probe > probe_distance(cctx, slot)) {
318 				insert(cctx, buffer, name, label, hash, probe);
319 				return;
320 			}
321 
322 			/*
323 			 * this slot matches, so provisionally set the
324 			 * return values and continue with the next label
325 			 */
326 			if (hash == cctx->set[slot].hash &&
327 			    match_suffix(buffer, coff, suffix_ptr, suffix_len,
328 					 *return_coff, sensitive))
329 			{
330 				*return_coff = coff;
331 				*return_prefix = prefix_len;
332 				break;
333 			}
334 		}
335 	}
336 }
337 
338 void
339 dns_compress_rollback(dns_compress_t *cctx, unsigned int coff) {
340 	REQUIRE(CCTX_VALID(cctx));
341 
342 	for (unsigned int slot = 0; slot <= cctx->mask; slot++) {
343 		if (cctx->set[slot].coff < coff) {
344 			continue;
345 		}
346 		/*
347 		 * The next few elements might be part of the deleted element's
348 		 * probe sequence, so we slide them down to overwrite the entry
349 		 * we are deleting and preserve the probe sequence. Moving an
350 		 * element to the previous slot reduces its probe distance, so
351 		 * we stop when we find an element whose probe distance is zero.
352 		 */
353 		unsigned int prev = slot;
354 		unsigned int next = slot_index(cctx, prev, 1);
355 		while (cctx->set[next].coff != 0 &&
356 		       probe_distance(cctx, next) != 0)
357 		{
358 			cctx->set[prev] = cctx->set[next];
359 			prev = next;
360 			next = slot_index(cctx, prev, 1);
361 		}
362 		cctx->set[prev].coff = 0;
363 		cctx->set[prev].hash = 0;
364 		cctx->count--;
365 	}
366 }
367