xref: /netbsd-src/sys/net/npf/lpm.c (revision f3cfa6f6ce31685c6c4a758bc430e69eb99f50a4)
1 /*-
2  * Copyright (c) 2016 Mindaugas Rasiukevicius <rmind at noxt eu>
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
15  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
20  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
21  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
22  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
23  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
24  * SUCH DAMAGE.
25  */
26 
27 /*
28  * Longest Prefix Match (LPM) library supporting IPv4 and IPv6.
29  *
30  * Algorithm:
31  *
 * Each prefix length gets its own hash map, and the set of prefix
 * lengths in use is recorded in a bitmap.  On a lookup, we perform a
 * linear scan of the hash maps, longest prefix first, visiting only the
 * prefix lengths which have been added.  Usually, only a few distinct
 * prefix lengths are in use, so this simple algorithm is very efficient.
 * With many distinct IPv6 prefix lengths, the linear scan might become
 * a bottleneck.
37  */
38 
39 #if defined(_KERNEL)
40 #include <sys/cdefs.h>
41 __KERNEL_RCSID(0, "$NetBSD: lpm.c,v 1.5 2018/09/29 14:41:36 rmind Exp $");
42 
43 #include <sys/param.h>
44 #include <sys/types.h>
45 #include <sys/malloc.h>
46 #include <sys/kmem.h>
47 #else
48 #include <sys/socket.h>
49 #include <arpa/inet.h>
50 
51 #include <stdio.h>
52 #include <stdlib.h>
53 #include <stdbool.h>
54 #include <stddef.h>
55 #include <string.h>
56 #include <strings.h>
57 #include <errno.h>
58 #include <assert.h>
59 #define kmem_alloc(a, b) malloc(a)
60 #define kmem_free(a, b) free(a)
61 #define kmem_zalloc(a, b) calloc(a, 1)
62 #endif
63 
64 #include "lpm.h"
65 
/* Longest supported prefix length, in bits (IPv6). */
#define	LPM_MAX_PREFIX		(128)
/* Number of 32-bit words needed to cover LPM_MAX_PREFIX bits. */
#define	LPM_MAX_WORDS		(LPM_MAX_PREFIX >> 5)
/* Convert an address length in bytes to a count of 32-bit words. */
#define	LPM_TO_WORDS(x)		((x) >> 2)
/* Head-room added to the item count when sizing a hash map. */
#define	LPM_HASH_STEP		(8)
/* Index into lpm::defvals[]: 4-byte addresses -> 0, 16-byte -> 1. */
#define	LPM_LEN_IDX(len)	((len) >> 4)

#ifdef DEBUG
#define	ASSERT			assert
#else
#define	ASSERT(x)
#endif

/*
 * The code below uses KASSERT().  Outside the kernel nothing in this
 * file defines it, so map it to ASSERT() unless some header already
 * provided a definition.
 */
#if !defined(_KERNEL) && !defined(KASSERT)
#define	KASSERT(x)		ASSERT(x)
#endif
77 
/*
 * Hash map entry: one stored prefix, linked into a bucket chain.
 */
struct lpm_ent {
	struct lpm_ent *next;	/* next entry in the same bucket */
	void *		val;	/* value associated with the prefix */
	unsigned	len;	/* length of the key, in bytes */
	uint8_t		key[];	/* the prefix bytes */
};
typedef struct lpm_ent lpm_ent_t;
84 
85 typedef struct {
86 	unsigned	hashsize;
87 	unsigned	nitems;
88 	lpm_ent_t **	bucket;
89 } lpm_hmap_t;
90 
91 struct lpm {
92 	uint32_t	bitmask[LPM_MAX_WORDS];
93 	void *		defvals[2];
94 	lpm_hmap_t	prefix[LPM_MAX_PREFIX + 1];
95 };
96 
97 static const uint32_t zero_address[LPM_MAX_WORDS];
98 
99 lpm_t *
100 lpm_create(void)
101 {
102 	return kmem_zalloc(sizeof(lpm_t), KM_SLEEP);
103 }
104 
105 void
106 lpm_clear(lpm_t *lpm, lpm_dtor_t dtor, void *arg)
107 {
108 	for (unsigned n = 0; n <= LPM_MAX_PREFIX; n++) {
109 		lpm_hmap_t *hmap = &lpm->prefix[n];
110 
111 		if (!hmap->hashsize) {
112 			KASSERT(!hmap->bucket);
113 			continue;
114 		}
115 		for (unsigned i = 0; i < hmap->hashsize; i++) {
116 			lpm_ent_t *entry = hmap->bucket[i];
117 
118 			while (entry) {
119 				lpm_ent_t *next = entry->next;
120 
121 				if (dtor) {
122 					dtor(arg, entry->key,
123 					    entry->len, entry->val);
124 				}
125 				kmem_free(entry,
126 				    offsetof(lpm_ent_t, key[entry->len]));
127 				entry = next;
128 			}
129 		}
130 		kmem_free(hmap->bucket, hmap->hashsize * sizeof(lpm_ent_t *));
131 		hmap->bucket = NULL;
132 		hmap->hashsize = 0;
133 		hmap->nitems = 0;
134 	}
135 	if (dtor) {
136 		dtor(arg, zero_address, 4, lpm->defvals[0]);
137 		dtor(arg, zero_address, 16, lpm->defvals[1]);
138 	}
139 	memset(lpm->bitmask, 0, sizeof(lpm->bitmask));
140 	memset(lpm->defvals, 0, sizeof(lpm->defvals));
141 }
142 
143 void
144 lpm_destroy(lpm_t *lpm)
145 {
146 	lpm_clear(lpm, NULL, NULL);
147 	kmem_free(lpm, sizeof(*lpm));
148 }
149 
/*
 * fnv1a_hash: 32-bit Fowler-Noll-Vo hash (FNV-1a variant) of the
 * given buffer of 'len' bytes.
 */
static uint32_t
fnv1a_hash(const void *buf, size_t len)
{
	const uint8_t *bytes = buf;
	uint32_t hash = 2166136261UL;

	for (size_t k = 0; k < len; k++) {
		/* FNV-1a: XOR the byte in first, then multiply. */
		hash ^= bytes[k];
		hash *= 16777619U;
	}
	return hash;
}
165 
166 static bool
167 hashmap_rehash(lpm_hmap_t *hmap, unsigned size)
168 {
169 	lpm_ent_t **bucket;
170 	unsigned hashsize;
171 
172 	for (hashsize = 1; hashsize < size; hashsize <<= 1) {
173 		continue;
174 	}
175 	bucket = kmem_zalloc(hashsize * sizeof(lpm_ent_t *), KM_SLEEP);
176 	for (unsigned n = 0; n < hmap->hashsize; n++) {
177 		lpm_ent_t *list = hmap->bucket[n];
178 
179 		while (list) {
180 			lpm_ent_t *entry = list;
181 			uint32_t hash = fnv1a_hash(entry->key, entry->len);
182 			const unsigned i = hash & (hashsize - 1);
183 
184 			list = entry->next;
185 			entry->next = bucket[i];
186 			bucket[i] = entry;
187 		}
188 	}
189 	if (hmap->bucket)
190 		kmem_free(hmap->bucket, hmap->hashsize * sizeof(lpm_ent_t *));
191 	hmap->bucket = bucket;
192 	hmap->hashsize = hashsize;
193 	return true;
194 }
195 
196 static lpm_ent_t *
197 hashmap_insert(lpm_hmap_t *hmap, const void *key, size_t len)
198 {
199 	const unsigned target = hmap->nitems + LPM_HASH_STEP;
200 	const size_t entlen = offsetof(lpm_ent_t, key[len]);
201 	uint32_t hash, i;
202 	lpm_ent_t *entry;
203 
204 	if (hmap->hashsize < target && !hashmap_rehash(hmap, target)) {
205 		return NULL;
206 	}
207 
208 	hash = fnv1a_hash(key, len);
209 	i = hash & (hmap->hashsize - 1);
210 	entry = hmap->bucket[i];
211 	while (entry) {
212 		if (entry->len == len && memcmp(entry->key, key, len) == 0) {
213 			return entry;
214 		}
215 		entry = entry->next;
216 	}
217 
218 	if ((entry = kmem_alloc(entlen, KM_SLEEP)) != NULL) {
219 		memcpy(entry->key, key, len);
220 		entry->next = hmap->bucket[i];
221 		entry->len = len;
222 
223 		hmap->bucket[i] = entry;
224 		hmap->nitems++;
225 	}
226 	return entry;
227 }
228 
229 static lpm_ent_t *
230 hashmap_lookup(lpm_hmap_t *hmap, const void *key, size_t len)
231 {
232 	const uint32_t hash = fnv1a_hash(key, len);
233 	const unsigned i = hash & (hmap->hashsize - 1);
234 	lpm_ent_t *entry;
235 
236 	if (hmap->hashsize == 0) {
237 		return NULL;
238 	}
239 	entry = hmap->bucket[i];
240 
241 	while (entry) {
242 		if (entry->len == len && memcmp(entry->key, key, len) == 0) {
243 			return entry;
244 		}
245 		entry = entry->next;
246 	}
247 	return NULL;
248 }
249 
250 static int
251 hashmap_remove(lpm_hmap_t *hmap, const void *key, size_t len)
252 {
253 	const uint32_t hash = fnv1a_hash(key, len);
254 	const unsigned i = hash & (hmap->hashsize - 1);
255 	lpm_ent_t *prev = NULL, *entry;
256 
257 	if (hmap->hashsize == 0) {
258 		return -1;
259 	}
260 	entry = hmap->bucket[i];
261 
262 	while (entry) {
263 		if (entry->len == len && memcmp(entry->key, key, len) == 0) {
264 			if (prev) {
265 				prev->next = entry->next;
266 			} else {
267 				hmap->bucket[i] = entry->next;
268 			}
269 			kmem_free(entry, offsetof(lpm_ent_t, key[len]));
270 			return 0;
271 		}
272 		prev = entry;
273 		entry = entry->next;
274 	}
275 	return -1;
276 }
277 
/*
 * compute_prefix: given the address of 'nwords' 32-bit words and the
 * prefix length in bits, store the address with all bits beyond the
 * prefix length cleared into 'prefix'.
 */
static inline void
compute_prefix(const unsigned nwords, const uint32_t *addr,
    unsigned preflen, uint32_t *prefix)
{
	uint32_t aligned[4];
	unsigned bits = preflen;

	if (((uintptr_t)addr & 3) != 0) {
		/* Unaligned address: just copy for now. */
		memcpy(aligned, addr, nwords * 4);
		addr = aligned;
	}
	for (unsigned w = 0; w < nwords; w++) {
		if (bits >= 32) {
			/* The word is entirely within the prefix. */
			prefix[w] = addr[w];
			bits -= 32;
		} else if (bits > 0) {
			/* The prefix ends within this word: mask the tail. */
			prefix[w] = addr[w] & htonl(0xffffffff << (32 - bits));
			bits = 0;
		} else {
			/* The word is entirely beyond the prefix. */
			prefix[w] = 0;
		}
	}
}
308 
309 /*
310  * lpm_insert: insert the CIDR into the LPM table.
311  *
312  * => Returns zero on success and -1 on failure.
313  */
314 int
315 lpm_insert(lpm_t *lpm, const void *addr,
316     size_t len, unsigned preflen, void *val)
317 {
318 	const unsigned nwords = LPM_TO_WORDS(len);
319 	uint32_t prefix[LPM_MAX_WORDS];
320 	lpm_ent_t *entry;
321 	KASSERT(len == 4 || len == 16);
322 
323 	if (preflen == 0) {
324 		/* 0-length prefix is a special case. */
325 		lpm->defvals[LPM_LEN_IDX(len)] = val;
326 		return 0;
327 	}
328 	compute_prefix(nwords, addr, preflen, prefix);
329 	entry = hashmap_insert(&lpm->prefix[preflen], prefix, len);
330 	if (entry) {
331 		const unsigned n = --preflen >> 5;
332 		lpm->bitmask[n] |= 0x80000000U >> (preflen & 31);
333 		entry->val = val;
334 		return 0;
335 	}
336 	return -1;
337 }
338 
339 /*
340  * lpm_remove: remove the specified prefix.
341  */
342 int
343 lpm_remove(lpm_t *lpm, const void *addr, size_t len, unsigned preflen)
344 {
345 	const unsigned nwords = LPM_TO_WORDS(len);
346 	uint32_t prefix[LPM_MAX_WORDS];
347 	KASSERT(len == 4 || len == 16);
348 
349 	if (preflen == 0) {
350 		lpm->defvals[LPM_LEN_IDX(len)] = NULL;
351 		return 0;
352 	}
353 	compute_prefix(nwords, addr, preflen, prefix);
354 	return hashmap_remove(&lpm->prefix[preflen], prefix, len);
355 }
356 
357 /*
358  * lpm_lookup: find the longest matching prefix given the IP address.
359  *
360  * => Returns the associated value on success or NULL on failure.
361  */
362 void *
363 lpm_lookup(lpm_t *lpm, const void *addr, size_t len)
364 {
365 	const unsigned nwords = LPM_TO_WORDS(len);
366 	unsigned i, n = nwords;
367 	uint32_t prefix[LPM_MAX_WORDS];
368 
369 	while (n--) {
370 		uint32_t bitmask = lpm->bitmask[n];
371 
372 		while ((i = ffs(bitmask)) != 0) {
373 			const unsigned preflen = (32 * n) + (32 - --i);
374 			lpm_hmap_t *hmap = &lpm->prefix[preflen];
375 			lpm_ent_t *entry;
376 
377 			compute_prefix(nwords, addr, preflen, prefix);
378 			entry = hashmap_lookup(hmap, prefix, len);
379 			if (entry) {
380 				return entry->val;
381 			}
382 			bitmask &= ~(1U << i);
383 		}
384 	}
385 	return lpm->defvals[LPM_LEN_IDX(len)];
386 }
387 
388 /*
389  * lpm_lookup_prefix: return the value associated with a prefix
390  *
391  * => Returns the associated value on success or NULL on failure.
392  */
393 void *
394 lpm_lookup_prefix(lpm_t *lpm, const void *addr, size_t len, unsigned preflen)
395 {
396 	const unsigned nwords = LPM_TO_WORDS(len);
397 	uint32_t prefix[LPM_MAX_WORDS];
398 	lpm_ent_t *entry;
399 	KASSERT(len == 4 || len == 16);
400 
401 	if (preflen == 0) {
402 		return lpm->defvals[LPM_LEN_IDX(len)];
403 	}
404 	compute_prefix(nwords, addr, preflen, prefix);
405 	entry = hashmap_lookup(&lpm->prefix[preflen], prefix, len);
406 	if (entry) {
407 		return entry->val;
408 	}
409 	return NULL;
410 }
411 
412 #if !defined(_KERNEL)
/*
 * lpm_strtobin: convert CIDR string to the binary IP address and mask.
 *
 * => The input is "address" or "address/prefix-length", where the
 *    address is in the IPv6 or IPv4 textual form.
 * => Without the prefix length, the full address length is used
 *    (128 for IPv6 and 32 for IPv4).
 * => The address will be in the network byte order; 'len' is set to
 *    its length in bytes (16 or 4).
 * => Returns 0 on success or -1 on failure: the string is too long,
 *    the address is invalid, the prefix length is not a plain decimal
 *    number or exceeds the address length in bits.
 * => On failure, 'addr', 'len' and 'preflen' are left untouched.
 */
int
lpm_strtobin(const char *cidr, void *addr, size_t *len, unsigned *preflen)
{
	/* Room for the longest textual address and the "/128" suffix. */
	char buf[INET6_ADDRSTRLEN + 4], *slash;
	const size_t slen = strlen(cidr);
	unsigned char bin[16];
	unsigned long bits = 0;
	bool have_bits = false;
	size_t alen;

	if (slen >= sizeof(buf)) {
		/* Fail rather than parse a silently truncated string. */
		return -1;
	}
	memcpy(buf, cidr, slen + 1);

	if ((slash = strchr(buf, '/')) != NULL) {
		const char *digits = slash + 1;
		char *end;

		/* strtoul() would accept leading spaces and a sign. */
		if (*digits < '0' || *digits > '9') {
			return -1;
		}
		errno = 0;
		bits = strtoul(digits, &end, 10);
		if (errno != 0 || *end != '\0') {
			return -1;
		}
		*slash = '\0';
		have_bits = true;
	}

	/* Parse into a local buffer: do not clobber 'addr' on failure. */
	if (inet_pton(AF_INET6, buf, bin) == 1) {
		alen = 16;
	} else if (inet_pton(AF_INET, buf, bin) == 1) {
		alen = 4;
	} else {
		return -1;
	}

	if (!have_bits) {
		bits = alen * 8;
	} else if (bits > alen * 8) {
		return -1;
	}
	memcpy(addr, bin, alen);
	*len = alen;
	*preflen = (unsigned)bits;
	return 0;
}
448 #endif
449