1 /*- 2 * Copyright (c) 2016 Mindaugas Rasiukevicius <rmind at noxt eu> 3 * All rights reserved. 4 * 5 * Redistribution and use in source and binary forms, with or without 6 * modification, are permitted provided that the following conditions 7 * are met: 8 * 1. Redistributions of source code must retain the above copyright 9 * notice, this list of conditions and the following disclaimer. 10 * 2. Redistributions in binary form must reproduce the above copyright 11 * notice, this list of conditions and the following disclaimer in the 12 * documentation and/or other materials provided with the distribution. 13 * 14 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND 15 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 17 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE 18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS 20 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 21 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 22 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 23 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF 24 * SUCH DAMAGE. 25 */ 26 27 /* 28 * TODO: Simple linear scan for now (works just well with a few prefixes). 29 * TBD on a better algorithm. 30 */ 31 32 #if defined(_KERNEL) 33 #include <sys/cdefs.h> 34 __KERNEL_RCSID(0, "$NetBSD: lpm.c,v 1.3 2016/12/26 21:16:06 rmind Exp $"); 35 36 #include <sys/param.h> 37 #include <sys/types.h> 38 #include <sys/malloc.h> 39 #include <sys/kmem.h> 40 #else 41 #include <sys/socket.h> 42 #include <arpa/inet.h> 43 44 #include <stdio.h> 45 #include <stdlib.h> 46 #include <stdbool.h> 47 #include <stddef.h> 48 #include <string.h> 49 #include <strings.h> 50 #include <errno.h> 51 #include <assert.h> 52 #define kmem_alloc(a, b) malloc(a) 53 #define kmem_free(a, b) free(a) 54 #define kmem_zalloc(a, b) calloc(a, 1) 55 #endif 56 57 #include "lpm.h" 58 59 #define LPM_MAX_PREFIX (128) 60 #define LPM_MAX_WORDS (LPM_MAX_PREFIX >> 5) 61 #define LPM_TO_WORDS(x) ((x) >> 2) 62 #define LPM_HASH_STEP (8) 63 64 #ifdef DEBUG 65 #define ASSERT assert 66 #else 67 #define ASSERT 68 #endif 69 70 typedef struct lpm_ent { 71 struct lpm_ent *next; 72 void * val; 73 unsigned len; 74 uint8_t key[]; 75 } lpm_ent_t; 76 77 typedef struct { 78 uint32_t hashsize; 79 uint32_t nitems; 80 lpm_ent_t **bucket; 81 } lpm_hmap_t; 82 83 struct lpm { 84 uint32_t bitmask[LPM_MAX_WORDS]; 85 void * defval; 86 lpm_hmap_t prefix[LPM_MAX_PREFIX + 1]; 87 }; 88 89 lpm_t * 90 lpm_create(void) 91 { 92 return kmem_zalloc(sizeof(lpm_t), KM_SLEEP); 93 } 94 95 void 96 lpm_clear(lpm_t *lpm, lpm_dtor_t dtor, void *arg) 97 { 98 for (unsigned n = 0; n <= LPM_MAX_PREFIX; n++) { 99 lpm_hmap_t *hmap = &lpm->prefix[n]; 100 101 if (!hmap->hashsize) { 102 KASSERT(!hmap->bucket); 103 continue; 104 } 105 for (unsigned i = 0; i < hmap->hashsize; i++) { 106 lpm_ent_t *entry = hmap->bucket[i]; 107 108 while (entry) { 109 lpm_ent_t *next = entry->next; 110 111 if (dtor) { 112 dtor(arg, entry->key, 113 entry->len, entry->val); 114 } 115 kmem_free(entry, 116 offsetof(lpm_ent_t, key[entry->len])); 117 entry = next; 118 } 119 } 120 kmem_free(hmap->bucket, hmap->hashsize * sizeof(lpm_ent_t *)); 121 hmap->bucket = NULL; 122 hmap->hashsize = 0; 123 hmap->nitems = 0; 124 } 125 memset(lpm->bitmask, 0, sizeof(lpm->bitmask)); 126 lpm->defval = NULL; 127 } 128 129 void 130 lpm_destroy(lpm_t *lpm) 131 { 132 lpm_clear(lpm, NULL, NULL); 133 kmem_free(lpm, sizeof(*lpm)); 134 } 135 136 /* 137 * fnv1a_hash: Fowler-Noll-Vo hash function (FNV-1a variant). 138 */ 139 static uint32_t 140 fnv1a_hash(const void *buf, size_t len) 141 { 142 uint32_t hash = 2166136261UL; 143 const uint8_t *p = buf; 144 145 while (len--) { 146 hash ^= *p++; 147 hash *= 16777619U; 148 } 149 return hash; 150 } 151 152 static bool 153 hashmap_rehash(lpm_hmap_t *hmap, uint32_t size) 154 { 155 lpm_ent_t **bucket; 156 uint32_t hashsize; 157 158 for (hashsize = 1; hashsize < size; hashsize <<= 1) { 159 continue; 160 } 161 bucket = kmem_zalloc(hashsize * sizeof(lpm_ent_t *), KM_SLEEP); 162 if (bucket == NULL) 163 return false; 164 for (unsigned n = 0; n < hmap->hashsize; n++) { 165 lpm_ent_t *list = hmap->bucket[n]; 166 167 while (list) { 168 lpm_ent_t *entry = list; 169 uint32_t hash = fnv1a_hash(entry->key, entry->len); 170 const size_t i = hash & (hashsize - 1); 171 172 list = entry->next; 173 entry->next = bucket[i]; 174 bucket[i] = entry; 175 } 176 } 177 if (hmap->bucket) 178 kmem_free(hmap->bucket, hmap->hashsize * sizeof(lpm_ent_t *)); 179 hmap->bucket = bucket; 180 hmap->hashsize = hashsize; 181 return true; 182 } 183 184 static lpm_ent_t * 185 hashmap_insert(lpm_hmap_t *hmap, const void *key, size_t len) 186 { 187 const uint32_t target = hmap->nitems + LPM_HASH_STEP; 188 const size_t entlen = offsetof(lpm_ent_t, key[len]); 189 uint32_t hash, i; 190 lpm_ent_t *entry; 191 192 if (hmap->hashsize < target && !hashmap_rehash(hmap, target)) { 193 return NULL; 194 } 195 196 hash = fnv1a_hash(key, len); 197 i = hash & (hmap->hashsize - 1); 198 entry = hmap->bucket[i]; 199 while (entry) { 200 if (entry->len == len && memcmp(entry->key, key, len) == 0) { 201 return entry; 202 } 203 entry = entry->next; 204 } 205 206 if ((entry = kmem_alloc(entlen, KM_SLEEP)) == NULL) 207 return NULL; 208 209 memcpy(entry->key, key, len); 210 entry->next = hmap->bucket[i]; 211 entry->len = len; 212 213 hmap->bucket[i] = entry; 214 hmap->nitems++; 215 return entry; 216 } 217 218 static lpm_ent_t * 219 hashmap_lookup(lpm_hmap_t *hmap, const void *key, size_t len) 220 { 221 const uint32_t hash = fnv1a_hash(key, len); 222 const uint32_t i = hash & (hmap->hashsize - 1); 223 lpm_ent_t *entry = hmap->bucket[i]; 224 225 while (entry) { 226 if (entry->len == len && memcmp(entry->key, key, len) == 0) { 227 return entry; 228 } 229 entry = entry->next; 230 } 231 return NULL; 232 } 233 234 static int 235 hashmap_remove(lpm_hmap_t *hmap, const void *key, size_t len) 236 { 237 const uint32_t hash = fnv1a_hash(key, len); 238 const uint32_t i = hash & (hmap->hashsize - 1); 239 lpm_ent_t *prev = NULL, *entry = hmap->bucket[i]; 240 241 while (entry) { 242 if (entry->len == len && memcmp(entry->key, key, len) == 0) { 243 if (prev) { 244 prev->next = entry->next; 245 } else { 246 hmap->bucket[i] = entry->next; 247 } 248 kmem_free(entry, offsetof(lpm_ent_t, key[len])); 249 return 0; 250 } 251 prev = entry; 252 entry = entry->next; 253 } 254 return -1; 255 } 256 257 /* 258 * compute_prefix: given the address and prefix length, compute and 259 * return the address prefix. 260 */ 261 static inline void 262 compute_prefix(const unsigned nwords, const uint32_t *addr, 263 unsigned preflen, uint32_t *prefix) 264 { 265 uint32_t addr2[4]; 266 267 if ((uintptr_t)addr & 3) { 268 /* Unaligned address: just copy for now. */ 269 memcpy(addr2, addr, nwords * 4); 270 addr = addr2; 271 } 272 for (unsigned i = 0; i < nwords; i++) { 273 if (preflen == 0) { 274 prefix[i] = 0; 275 continue; 276 } 277 if (preflen < 32) { 278 uint32_t mask = htonl(0xffffffff << (32 - preflen)); 279 prefix[i] = addr[i] & mask; 280 preflen = 0; 281 } else { 282 prefix[i] = addr[i]; 283 preflen -= 32; 284 } 285 } 286 } 287 288 /* 289 * lpm_insert: insert the CIDR into the LPM table. 290 * 291 * => Returns zero on success and -1 on failure. 292 */ 293 int 294 lpm_insert(lpm_t *lpm, const void *addr, 295 size_t len, unsigned preflen, void *val) 296 { 297 const unsigned nwords = LPM_TO_WORDS(len); 298 uint32_t prefix[LPM_MAX_WORDS]; 299 lpm_ent_t *entry; 300 301 if (preflen == 0) { 302 /* Default is a special case. */ 303 lpm->defval = val; 304 return 0; 305 } 306 compute_prefix(nwords, addr, preflen, prefix); 307 entry = hashmap_insert(&lpm->prefix[preflen], prefix, len); 308 if (entry) { 309 const unsigned n = --preflen >> 5; 310 lpm->bitmask[n] |= 0x80000000U >> (preflen & 31); 311 entry->val = val; 312 return 0; 313 } 314 return -1; 315 } 316 317 /* 318 * lpm_remove: remove the specified prefix. 319 */ 320 int 321 lpm_remove(lpm_t *lpm, const void *addr, size_t len, unsigned preflen) 322 { 323 const unsigned nwords = LPM_TO_WORDS(len); 324 uint32_t prefix[LPM_MAX_WORDS]; 325 326 if (preflen == 0) { 327 lpm->defval = NULL; 328 return 0; 329 } 330 compute_prefix(nwords, addr, preflen, prefix); 331 return hashmap_remove(&lpm->prefix[preflen], prefix, len); 332 } 333 334 /* 335 * lpm_lookup: find the longest matching prefix given the IP address. 336 * 337 * => Returns the associated value on success or NULL on failure. 338 */ 339 void * 340 lpm_lookup(lpm_t *lpm, const void *addr, size_t len) 341 { 342 const unsigned nwords = LPM_TO_WORDS(len); 343 unsigned i, n = nwords; 344 uint32_t prefix[LPM_MAX_WORDS]; 345 346 while (n--) { 347 uint32_t bitmask = lpm->bitmask[n]; 348 349 while ((i = ffs(bitmask)) != 0) { 350 const unsigned preflen = (32 * n) + (32 - --i); 351 lpm_hmap_t *hmap = &lpm->prefix[preflen]; 352 lpm_ent_t *entry; 353 354 compute_prefix(nwords, addr, preflen, prefix); 355 entry = hashmap_lookup(hmap, prefix, len); 356 if (entry) { 357 return entry->val; 358 } 359 bitmask &= ~(1U << i); 360 } 361 } 362 return lpm->defval; 363 } 364 365 #if !defined(_KERNEL) 366 /* 367 * lpm_strtobin: convert CIDR string to the binary IP address and mask. 368 * 369 * => The address will be in the network byte order. 370 * => Returns 0 on success or -1 on failure. 371 */ 372 int 373 lpm_strtobin(const char *cidr, void *addr, size_t *len, unsigned *preflen) 374 { 375 char *p, buf[INET6_ADDRSTRLEN]; 376 377 strncpy(buf, cidr, sizeof(buf)); 378 buf[sizeof(buf) - 1] = '\0'; 379 380 if ((p = strchr(buf, '/')) != NULL) { 381 const ptrdiff_t off = p - buf; 382 *preflen = atoi(&buf[off + 1]); 383 buf[off] = '\0'; 384 } else { 385 *preflen = LPM_MAX_PREFIX; 386 } 387 388 if (inet_pton(AF_INET6, buf, addr) == 1) { 389 *len = 16; 390 return 0; 391 } 392 if (inet_pton(AF_INET, buf, addr) == 1) { 393 if (*preflen == LPM_MAX_PREFIX) { 394 *preflen = 32; 395 } 396 *len = 4; 397 return 0; 398 } 399 return -1; 400 } 401 #endif 402