1 /*- 2 * Copyright (c) 2016 Mindaugas Rasiukevicius <rmind at noxt eu> 3 * All rights reserved. 4 * 5 * Redistribution and use in source and binary forms, with or without 6 * modification, are permitted provided that the following conditions 7 * are met: 8 * 1. Redistributions of source code must retain the above copyright 9 * notice, this list of conditions and the following disclaimer. 10 * 2. Redistributions in binary form must reproduce the above copyright 11 * notice, this list of conditions and the following disclaimer in the 12 * documentation and/or other materials provided with the distribution. 13 * 14 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND 15 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 17 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE 18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS 20 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 21 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 22 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 23 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF 24 * SUCH DAMAGE. 25 */ 26 27 /* 28 * TODO: Simple linear scan for now (works just well with a few prefixes). 29 * TBD on a better algorithm. 30 */ 31 32 #if defined(_KERNEL) 33 #include <sys/cdefs.h> 34 __KERNEL_RCSID(0, "$NetBSD: lpm.c,v 1.4 2017/06/01 02:45:14 chs Exp $"); 35 36 #include <sys/param.h> 37 #include <sys/types.h> 38 #include <sys/malloc.h> 39 #include <sys/kmem.h> 40 #else 41 #include <sys/socket.h> 42 #include <arpa/inet.h> 43 44 #include <stdio.h> 45 #include <stdlib.h> 46 #include <stdbool.h> 47 #include <stddef.h> 48 #include <string.h> 49 #include <strings.h> 50 #include <errno.h> 51 #include <assert.h> 52 #define kmem_alloc(a, b) malloc(a) 53 #define kmem_free(a, b) free(a) 54 #define kmem_zalloc(a, b) calloc(a, 1) 55 #endif 56 57 #include "lpm.h" 58 59 #define LPM_MAX_PREFIX (128) 60 #define LPM_MAX_WORDS (LPM_MAX_PREFIX >> 5) 61 #define LPM_TO_WORDS(x) ((x) >> 2) 62 #define LPM_HASH_STEP (8) 63 64 #ifdef DEBUG 65 #define ASSERT assert 66 #else 67 #define ASSERT 68 #endif 69 70 typedef struct lpm_ent { 71 struct lpm_ent *next; 72 void * val; 73 unsigned len; 74 uint8_t key[]; 75 } lpm_ent_t; 76 77 typedef struct { 78 uint32_t hashsize; 79 uint32_t nitems; 80 lpm_ent_t **bucket; 81 } lpm_hmap_t; 82 83 struct lpm { 84 uint32_t bitmask[LPM_MAX_WORDS]; 85 void * defval; 86 lpm_hmap_t prefix[LPM_MAX_PREFIX + 1]; 87 }; 88 89 lpm_t * 90 lpm_create(void) 91 { 92 return kmem_zalloc(sizeof(lpm_t), KM_SLEEP); 93 } 94 95 void 96 lpm_clear(lpm_t *lpm, lpm_dtor_t dtor, void *arg) 97 { 98 for (unsigned n = 0; n <= LPM_MAX_PREFIX; n++) { 99 lpm_hmap_t *hmap = &lpm->prefix[n]; 100 101 if (!hmap->hashsize) { 102 KASSERT(!hmap->bucket); 103 continue; 104 } 105 for (unsigned i = 0; i < hmap->hashsize; i++) { 106 lpm_ent_t *entry = hmap->bucket[i]; 107 108 while (entry) { 109 lpm_ent_t *next = entry->next; 110 111 if (dtor) { 112 dtor(arg, entry->key, 113 entry->len, entry->val); 114 } 115 kmem_free(entry, 116 offsetof(lpm_ent_t, key[entry->len])); 117 entry = next; 118 } 119 } 120 kmem_free(hmap->bucket, hmap->hashsize * sizeof(lpm_ent_t *)); 121 hmap->bucket = NULL; 122 hmap->hashsize = 0; 123 hmap->nitems = 0; 124 } 125 memset(lpm->bitmask, 0, sizeof(lpm->bitmask)); 126 lpm->defval = NULL; 127 } 128 129 void 130 lpm_destroy(lpm_t *lpm) 131 { 132 lpm_clear(lpm, NULL, NULL); 133 kmem_free(lpm, sizeof(*lpm)); 134 } 135 136 /* 137 * fnv1a_hash: Fowler-Noll-Vo hash function (FNV-1a variant). 138 */ 139 static uint32_t 140 fnv1a_hash(const void *buf, size_t len) 141 { 142 uint32_t hash = 2166136261UL; 143 const uint8_t *p = buf; 144 145 while (len--) { 146 hash ^= *p++; 147 hash *= 16777619U; 148 } 149 return hash; 150 } 151 152 static bool 153 hashmap_rehash(lpm_hmap_t *hmap, uint32_t size) 154 { 155 lpm_ent_t **bucket; 156 uint32_t hashsize; 157 158 for (hashsize = 1; hashsize < size; hashsize <<= 1) { 159 continue; 160 } 161 bucket = kmem_zalloc(hashsize * sizeof(lpm_ent_t *), KM_SLEEP); 162 for (unsigned n = 0; n < hmap->hashsize; n++) { 163 lpm_ent_t *list = hmap->bucket[n]; 164 165 while (list) { 166 lpm_ent_t *entry = list; 167 uint32_t hash = fnv1a_hash(entry->key, entry->len); 168 const size_t i = hash & (hashsize - 1); 169 170 list = entry->next; 171 entry->next = bucket[i]; 172 bucket[i] = entry; 173 } 174 } 175 if (hmap->bucket) 176 kmem_free(hmap->bucket, hmap->hashsize * sizeof(lpm_ent_t *)); 177 hmap->bucket = bucket; 178 hmap->hashsize = hashsize; 179 return true; 180 } 181 182 static lpm_ent_t * 183 hashmap_insert(lpm_hmap_t *hmap, const void *key, size_t len) 184 { 185 const uint32_t target = hmap->nitems + LPM_HASH_STEP; 186 const size_t entlen = offsetof(lpm_ent_t, key[len]); 187 uint32_t hash, i; 188 lpm_ent_t *entry; 189 190 if (hmap->hashsize < target && !hashmap_rehash(hmap, target)) { 191 return NULL; 192 } 193 194 hash = fnv1a_hash(key, len); 195 i = hash & (hmap->hashsize - 1); 196 entry = hmap->bucket[i]; 197 while (entry) { 198 if (entry->len == len && memcmp(entry->key, key, len) == 0) { 199 return entry; 200 } 201 entry = entry->next; 202 } 203 204 entry = kmem_alloc(entlen, KM_SLEEP); 205 memcpy(entry->key, key, len); 206 entry->next = hmap->bucket[i]; 207 entry->len = len; 208 209 hmap->bucket[i] = entry; 210 hmap->nitems++; 211 return entry; 212 } 213 214 static lpm_ent_t * 215 hashmap_lookup(lpm_hmap_t *hmap, const void *key, size_t len) 216 { 217 const uint32_t hash = fnv1a_hash(key, len); 218 const uint32_t i = hash & (hmap->hashsize - 1); 219 lpm_ent_t *entry = hmap->bucket[i]; 220 221 while (entry) { 222 if (entry->len == len && memcmp(entry->key, key, len) == 0) { 223 return entry; 224 } 225 entry = entry->next; 226 } 227 return NULL; 228 } 229 230 static int 231 hashmap_remove(lpm_hmap_t *hmap, const void *key, size_t len) 232 { 233 const uint32_t hash = fnv1a_hash(key, len); 234 const uint32_t i = hash & (hmap->hashsize - 1); 235 lpm_ent_t *prev = NULL, *entry = hmap->bucket[i]; 236 237 while (entry) { 238 if (entry->len == len && memcmp(entry->key, key, len) == 0) { 239 if (prev) { 240 prev->next = entry->next; 241 } else { 242 hmap->bucket[i] = entry->next; 243 } 244 kmem_free(entry, offsetof(lpm_ent_t, key[len])); 245 return 0; 246 } 247 prev = entry; 248 entry = entry->next; 249 } 250 return -1; 251 } 252 253 /* 254 * compute_prefix: given the address and prefix length, compute and 255 * return the address prefix. 256 */ 257 static inline void 258 compute_prefix(const unsigned nwords, const uint32_t *addr, 259 unsigned preflen, uint32_t *prefix) 260 { 261 uint32_t addr2[4]; 262 263 if ((uintptr_t)addr & 3) { 264 /* Unaligned address: just copy for now. */ 265 memcpy(addr2, addr, nwords * 4); 266 addr = addr2; 267 } 268 for (unsigned i = 0; i < nwords; i++) { 269 if (preflen == 0) { 270 prefix[i] = 0; 271 continue; 272 } 273 if (preflen < 32) { 274 uint32_t mask = htonl(0xffffffff << (32 - preflen)); 275 prefix[i] = addr[i] & mask; 276 preflen = 0; 277 } else { 278 prefix[i] = addr[i]; 279 preflen -= 32; 280 } 281 } 282 } 283 284 /* 285 * lpm_insert: insert the CIDR into the LPM table. 286 * 287 * => Returns zero on success and -1 on failure. 288 */ 289 int 290 lpm_insert(lpm_t *lpm, const void *addr, 291 size_t len, unsigned preflen, void *val) 292 { 293 const unsigned nwords = LPM_TO_WORDS(len); 294 uint32_t prefix[LPM_MAX_WORDS]; 295 lpm_ent_t *entry; 296 297 if (preflen == 0) { 298 /* Default is a special case. */ 299 lpm->defval = val; 300 return 0; 301 } 302 compute_prefix(nwords, addr, preflen, prefix); 303 entry = hashmap_insert(&lpm->prefix[preflen], prefix, len); 304 if (entry) { 305 const unsigned n = --preflen >> 5; 306 lpm->bitmask[n] |= 0x80000000U >> (preflen & 31); 307 entry->val = val; 308 return 0; 309 } 310 return -1; 311 } 312 313 /* 314 * lpm_remove: remove the specified prefix. 315 */ 316 int 317 lpm_remove(lpm_t *lpm, const void *addr, size_t len, unsigned preflen) 318 { 319 const unsigned nwords = LPM_TO_WORDS(len); 320 uint32_t prefix[LPM_MAX_WORDS]; 321 322 if (preflen == 0) { 323 lpm->defval = NULL; 324 return 0; 325 } 326 compute_prefix(nwords, addr, preflen, prefix); 327 return hashmap_remove(&lpm->prefix[preflen], prefix, len); 328 } 329 330 /* 331 * lpm_lookup: find the longest matching prefix given the IP address. 332 * 333 * => Returns the associated value on success or NULL on failure. 334 */ 335 void * 336 lpm_lookup(lpm_t *lpm, const void *addr, size_t len) 337 { 338 const unsigned nwords = LPM_TO_WORDS(len); 339 unsigned i, n = nwords; 340 uint32_t prefix[LPM_MAX_WORDS]; 341 342 while (n--) { 343 uint32_t bitmask = lpm->bitmask[n]; 344 345 while ((i = ffs(bitmask)) != 0) { 346 const unsigned preflen = (32 * n) + (32 - --i); 347 lpm_hmap_t *hmap = &lpm->prefix[preflen]; 348 lpm_ent_t *entry; 349 350 compute_prefix(nwords, addr, preflen, prefix); 351 entry = hashmap_lookup(hmap, prefix, len); 352 if (entry) { 353 return entry->val; 354 } 355 bitmask &= ~(1U << i); 356 } 357 } 358 return lpm->defval; 359 } 360 361 #if !defined(_KERNEL) 362 /* 363 * lpm_strtobin: convert CIDR string to the binary IP address and mask. 364 * 365 * => The address will be in the network byte order. 366 * => Returns 0 on success or -1 on failure. 367 */ 368 int 369 lpm_strtobin(const char *cidr, void *addr, size_t *len, unsigned *preflen) 370 { 371 char *p, buf[INET6_ADDRSTRLEN]; 372 373 strncpy(buf, cidr, sizeof(buf)); 374 buf[sizeof(buf) - 1] = '\0'; 375 376 if ((p = strchr(buf, '/')) != NULL) { 377 const ptrdiff_t off = p - buf; 378 *preflen = atoi(&buf[off + 1]); 379 buf[off] = '\0'; 380 } else { 381 *preflen = LPM_MAX_PREFIX; 382 } 383 384 if (inet_pton(AF_INET6, buf, addr) == 1) { 385 *len = 16; 386 return 0; 387 } 388 if (inet_pton(AF_INET, buf, addr) == 1) { 389 if (*preflen == LPM_MAX_PREFIX) { 390 *preflen = 32; 391 } 392 *len = 4; 393 return 0; 394 } 395 return -1; 396 } 397 #endif 398