xref: /openbsd-src/sys/netinet/ip_ipsp.c (revision f1dd7b858388b4a23f4f67a4957ec5ff656ebbe8)
1 /*	$OpenBSD: ip_ipsp.c,v 1.238 2021/03/10 10:21:49 jsg Exp $	*/
2 /*
3  * The authors of this code are John Ioannidis (ji@tla.org),
4  * Angelos D. Keromytis (kermit@csd.uch.gr),
5  * Niels Provos (provos@physnet.uni-hamburg.de) and
6  * Niklas Hallqvist (niklas@appli.se).
7  *
8  * The original version of this code was written by John Ioannidis
9  * for BSD/OS in Athens, Greece, in November 1995.
10  *
11  * Ported to OpenBSD and NetBSD, with additional transforms, in December 1996,
12  * by Angelos D. Keromytis.
13  *
14  * Additional transforms and features in 1997 and 1998 by Angelos D. Keromytis
15  * and Niels Provos.
16  *
17  * Additional features in 1999 by Angelos D. Keromytis and Niklas Hallqvist.
18  *
19  * Copyright (c) 1995, 1996, 1997, 1998, 1999 by John Ioannidis,
20  * Angelos D. Keromytis and Niels Provos.
21  * Copyright (c) 1999 Niklas Hallqvist.
22  * Copyright (c) 2001, Angelos D. Keromytis.
23  *
24  * Permission to use, copy, and modify this software with or without fee
25  * is hereby granted, provided that this entire notice is included in
26  * all copies of any software which is or includes a copy or
27  * modification of this software.
28  * You may use this code under the GNU public license if you so wish. Please
29  * contribute changes back to the authors under this freer than GPL license
30  * so that we may further the use of strong encryption without limitations to
31  * all.
32  *
33  * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR
34  * IMPLIED WARRANTY. IN PARTICULAR, NONE OF THE AUTHORS MAKES ANY
35  * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE
36  * MERCHANTABILITY OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR
37  * PURPOSE.
38  */
39 
40 #include "pf.h"
41 #include "pfsync.h"
42 
43 #include <sys/param.h>
44 #include <sys/systm.h>
45 #include <sys/mbuf.h>
46 #include <sys/socket.h>
47 #include <sys/kernel.h>
48 #include <sys/timeout.h>
49 #include <sys/pool.h>
50 
51 #include <net/if.h>
52 #include <net/route.h>
53 
54 #include <netinet/in.h>
55 #include <netinet/ip.h>
56 #include <netinet/in_pcb.h>
57 #include <netinet/ip_var.h>
58 #include <netinet/ip_ipip.h>
59 
60 #if NPF > 0
61 #include <net/pfvar.h>
62 #endif
63 
64 #if NPFSYNC > 0
65 #include <net/if_pfsync.h>
66 #endif
67 
68 #include <netinet/ip_ipsp.h>
69 #include <net/pfkeyv2.h>
70 
71 #ifdef DDB
72 #include <ddb/db_output.h>
73 void tdb_hashstats(void);
74 #endif
75 
76 #ifdef ENCDEBUG
77 #define	DPRINTF(x)	if (encdebug) printf x
78 #else
79 #define	DPRINTF(x)
80 #endif
81 
/* Forward declarations for file-local helpers defined below. */
void		tdb_rehash(void);
void		tdb_reaper(void *);
void		tdb_timeout(void *);
void		tdb_firstuse(void *);
void		tdb_soft_timeout(void *);
void		tdb_soft_firstuse(void *);
int		tdb_hash(u_int32_t, union sockaddr_union *, u_int8_t);

int ipsec_in_use = 0;
u_int64_t ipsec_last_added = 0;		/* uptime of the last puttdb() */
int ipsec_ids_idle = 100;		/* keep free ids for 100s */

struct pool tdb_pool;			/* backing allocator for struct tdb */

/* Protected by the NET_LOCK(). */
u_int32_t ipsec_ids_next_flow = 1;	/* may not be zero */
struct ipsec_ids_tree ipsec_ids_tree;	/* keyed by identity payloads (ipsp_ids_cmp) */
struct ipsec_ids_flows ipsec_ids_flows;	/* keyed by flow number (ipsp_ids_flow_cmp) */
struct ipsec_policy_head ipsec_policy_head =
    TAILQ_HEAD_INITIALIZER(ipsec_policy_head);

void ipsp_ids_timeout(void *);
static inline int ipsp_ids_cmp(const struct ipsec_ids *,
    const struct ipsec_ids *);
static inline int ipsp_ids_flow_cmp(const struct ipsec_ids *,
    const struct ipsec_ids *);
RBT_PROTOTYPE(ipsec_ids_tree, ipsec_ids, id_node_flow, ipsp_ids_cmp);
RBT_PROTOTYPE(ipsec_ids_flows, ipsec_ids, id_node_id, ipsp_ids_flow_cmp);
RBT_GENERATE(ipsec_ids_tree, ipsec_ids, id_node_flow, ipsp_ids_cmp);
RBT_GENERATE(ipsec_ids_flows, ipsec_ids, id_node_id, ipsp_ids_flow_cmp);
112 
/*
 * This is the proper place to define the various encapsulation transforms.
 */

struct xformsw xformsw[] = {
#ifdef IPSEC
{
  .xf_type	= XF_IP4,
  .xf_flags	= 0,
  .xf_name	= "IPv4 Simple Encapsulation",
  .xf_attach	= ipe4_attach,
  .xf_init	= ipe4_init,
  .xf_zeroize	= ipe4_zeroize,
  .xf_input	= ipe4_input,
  .xf_output	= ipip_output,
},
{
  .xf_type	= XF_AH,
  .xf_flags	= XFT_AUTH,
  .xf_name	= "IPsec AH",
  .xf_attach	= ah_attach,
  .xf_init	= ah_init,
  .xf_zeroize	= ah_zeroize,
  .xf_input	= ah_input,
  .xf_output	= ah_output,
},
{
  .xf_type	= XF_ESP,
  .xf_flags	= XFT_CONF|XFT_AUTH,
  .xf_name	= "IPsec ESP",
  .xf_attach	= esp_attach,
  .xf_init	= esp_init,
  .xf_zeroize	= esp_zeroize,
  .xf_input	= esp_input,
  .xf_output	= esp_output,
},
{
  .xf_type	= XF_IPCOMP,
  .xf_flags	= XFT_COMP,
  .xf_name	= "IPcomp",
  .xf_attach	= ipcomp_attach,
  .xf_init	= ipcomp_init,
  .xf_zeroize	= ipcomp_zeroize,
  .xf_input	= ipcomp_input,
  .xf_output	= ipcomp_output,
},
#endif /* IPSEC */
#ifdef TCP_SIGNATURE
{
  .xf_type	= XF_TCPSIGNATURE,
  .xf_flags	= XFT_AUTH,
  .xf_name	= "TCP MD5 Signature Option, RFC 2385",
  .xf_attach	= tcp_signature_tdb_attach,
  .xf_init	= tcp_signature_tdb_init,
  .xf_zeroize	= tcp_signature_tdb_zeroize,
  .xf_input	= tcp_signature_tdb_input,
  .xf_output	= tcp_signature_tdb_output,
}
#endif /* TCP_SIGNATURE */
};

/* One past the last table entry; loop bound used by tdb_init(). */
struct xformsw *xformswNXFORMSW = &xformsw[nitems(xformsw)];
175 
#define	TDB_HASHSIZE_INIT	32

/* Protected by the NET_LOCK(). */
static SIPHASH_KEY tdbkey;		/* hash key; re-randomized by tdb_rehash() */
static struct tdb **tdbh = NULL;	/* chains hashed on SPI/dst/proto */
static struct tdb **tdbdst = NULL;	/* chains hashed on dst/proto */
static struct tdb **tdbsrc = NULL;	/* chains hashed on src/proto */
static u_int tdb_hashmask = TDB_HASHSIZE_INIT - 1;
static int tdb_count;			/* number of TDBs currently linked in */
185 
186 /*
187  * Our hashing function needs to stir things with a non-zero random multiplier
188  * so we cannot be DoS-attacked via choosing of the data to hash.
189  */
190 int
191 tdb_hash(u_int32_t spi, union sockaddr_union *dst,
192     u_int8_t proto)
193 {
194 	SIPHASH_CTX ctx;
195 
196 	NET_ASSERT_LOCKED();
197 
198 	SipHash24_Init(&ctx, &tdbkey);
199 	SipHash24_Update(&ctx, &spi, sizeof(spi));
200 	SipHash24_Update(&ctx, &proto, sizeof(proto));
201 	SipHash24_Update(&ctx, dst, dst->sa.sa_len);
202 
203 	return (SipHash24_End(&ctx) & tdb_hashmask);
204 }
205 
/*
 * Reserve an SPI; the SA is not valid yet though.  We use 0 as
 * an error return value.  On success a TDB marked TDBF_INVALID is
 * linked into the tables under the chosen SPI; on failure *errval
 * is set (EINVAL for a bad range, EEXIST if no free SPI was found).
 */
u_int32_t
reserve_spi(u_int rdomain, u_int32_t sspi, u_int32_t tspi,
    union sockaddr_union *src, union sockaddr_union *dst,
    u_int8_t sproto, int *errval)
{
	struct tdb *tdbp, *exists;
	u_int32_t spi;
	int nums;

	NET_ASSERT_LOCKED();

	/* Don't accept ranges only encompassing reserved SPIs. */
	if (sproto != IPPROTO_IPCOMP &&
	    (tspi < sspi || tspi <= SPI_RESERVED_MAX)) {
		(*errval) = EINVAL;
		return 0;
	}
	/* IPCOMP uses the 16-bit CPI space with its own reserved ranges. */
	if (sproto == IPPROTO_IPCOMP && (tspi < sspi ||
	    tspi <= CPI_RESERVED_MAX ||
	    tspi >= CPI_PRIVATE_MIN)) {
		(*errval) = EINVAL;
		return 0;
	}

	/* Limit the range to not include reserved areas. */
	if (sspi <= SPI_RESERVED_MAX)
		sspi = SPI_RESERVED_MAX + 1;

	/* For IPCOMP the CPI is only 16 bits long, what a good idea.... */

	if (sproto == IPPROTO_IPCOMP) {
		u_int32_t t;
		if (sspi >= 0x10000)
			sspi = 0xffff;
		if (tspi >= 0x10000)
			tspi = 0xffff;
		/* Clamping above may have inverted the range; swap back. */
		if (sspi > tspi) {
			t = sspi; sspi = tspi; tspi = t;
		}
	}

	if (sspi == tspi)   /* Asking for a specific SPI. */
		nums = 1;
	else
		nums = 100;  /* Arbitrarily chosen */

	/* allocate ahead of time to avoid potential sleeping race in loop */
	tdbp = tdb_alloc(rdomain);

	while (nums--) {
		if (sspi == tspi)  /* Specific SPI asked. */
			spi = tspi;
		else    /* Range specified */
			spi = sspi + arc4random_uniform(tspi - sspi);

		/* Don't allocate reserved SPIs.  */
		if (spi >= SPI_RESERVED_MIN && spi <= SPI_RESERVED_MAX)
			continue;
		else
			spi = htonl(spi);

		/* Check whether we're using this SPI already. */
		exists = gettdb(rdomain, spi, dst, sproto);
		if (exists)
			continue;

		tdbp->tdb_spi = spi;
		memcpy(&tdbp->tdb_dst.sa, &dst->sa, dst->sa.sa_len);
		memcpy(&tdbp->tdb_src.sa, &src->sa, src->sa.sa_len);
		tdbp->tdb_sproto = sproto;
		tdbp->tdb_flags |= TDBF_INVALID; /* Mark SA invalid for now. */
		tdbp->tdb_satype = SADB_SATYPE_UNSPEC;
		puttdb(tdbp);

#ifdef IPSEC
		/* Setup a "silent" expiration (since TDBF_INVALID's set). */
		if (ipsec_keep_invalid > 0) {
			tdbp->tdb_flags |= TDBF_TIMER;
			tdbp->tdb_exp_timeout = ipsec_keep_invalid;
			timeout_add_sec(&tdbp->tdb_timer_tmo,
			    ipsec_keep_invalid);
		}
#endif

		return spi;
	}

	/* No free SPI found within `nums' attempts. */
	(*errval) = EEXIST;
	tdb_free(tdbp);
	return 0;
}
302 
303 /*
304  * An IPSP SAID is really the concatenation of the SPI found in the
305  * packet, the destination address of the packet and the IPsec protocol.
306  * When we receive an IPSP packet, we need to look up its tunnel descriptor
307  * block, based on the SPI in the packet and the destination address (which
308  * is really one of our addresses if we received the packet!
309  */
310 struct tdb *
311 gettdb_dir(u_int rdomain, u_int32_t spi, union sockaddr_union *dst, u_int8_t proto,
312     int reverse)
313 {
314 	u_int32_t hashval;
315 	struct tdb *tdbp;
316 
317 	NET_ASSERT_LOCKED();
318 
319 	if (tdbh == NULL)
320 		return (struct tdb *) NULL;
321 
322 	hashval = tdb_hash(spi, dst, proto);
323 
324 	for (tdbp = tdbh[hashval]; tdbp != NULL; tdbp = tdbp->tdb_hnext)
325 		if ((tdbp->tdb_spi == spi) && (tdbp->tdb_sproto == proto) &&
326 		    ((!reverse && tdbp->tdb_rdomain == rdomain) ||
327 		    (reverse && tdbp->tdb_rdomain_post == rdomain)) &&
328 		    !memcmp(&tdbp->tdb_dst, dst, dst->sa.sa_len))
329 			break;
330 
331 	return tdbp;
332 }
333 
334 /*
335  * Same as gettdb() but compare SRC as well, so we
336  * use the tdbsrc[] hash table.  Setting spi to 0
337  * matches all SPIs.
338  */
339 struct tdb *
340 gettdbbysrcdst_dir(u_int rdomain, u_int32_t spi, union sockaddr_union *src,
341     union sockaddr_union *dst, u_int8_t proto, int reverse)
342 {
343 	u_int32_t hashval;
344 	struct tdb *tdbp;
345 	union sockaddr_union su_null;
346 
347 	NET_ASSERT_LOCKED();
348 
349 	if (tdbsrc == NULL)
350 		return (struct tdb *) NULL;
351 
352 	hashval = tdb_hash(0, src, proto);
353 
354 	for (tdbp = tdbsrc[hashval]; tdbp != NULL; tdbp = tdbp->tdb_snext)
355 		if (tdbp->tdb_sproto == proto &&
356 		    (spi == 0 || tdbp->tdb_spi == spi) &&
357 		    ((!reverse && tdbp->tdb_rdomain == rdomain) ||
358 		    (reverse && tdbp->tdb_rdomain_post == rdomain)) &&
359 		    ((tdbp->tdb_flags & TDBF_INVALID) == 0) &&
360 		    (tdbp->tdb_dst.sa.sa_family == AF_UNSPEC ||
361 		    !memcmp(&tdbp->tdb_dst, dst, dst->sa.sa_len)) &&
362 		    !memcmp(&tdbp->tdb_src, src, src->sa.sa_len))
363 			break;
364 
365 	if (tdbp != NULL)
366 		return (tdbp);
367 
368 	memset(&su_null, 0, sizeof(su_null));
369 	su_null.sa.sa_len = sizeof(struct sockaddr);
370 	hashval = tdb_hash(0, &su_null, proto);
371 
372 	for (tdbp = tdbsrc[hashval]; tdbp != NULL; tdbp = tdbp->tdb_snext)
373 		if (tdbp->tdb_sproto == proto &&
374 		    (spi == 0 || tdbp->tdb_spi == spi) &&
375 		    ((!reverse && tdbp->tdb_rdomain == rdomain) ||
376 		    (reverse && tdbp->tdb_rdomain_post == rdomain)) &&
377 		    ((tdbp->tdb_flags & TDBF_INVALID) == 0) &&
378 		    (tdbp->tdb_dst.sa.sa_family == AF_UNSPEC ||
379 		    !memcmp(&tdbp->tdb_dst, dst, dst->sa.sa_len)) &&
380 		    tdbp->tdb_src.sa.sa_family == AF_UNSPEC)
381 			break;
382 
383 	return (tdbp);
384 }
385 
386 /*
387  * Check that IDs match. Return true if so. The t* range of
388  * arguments contains information from TDBs; the p* range of
389  * arguments contains information from policies or already
390  * established TDBs.
391  */
392 int
393 ipsp_aux_match(struct tdb *tdb,
394     struct ipsec_ids *ids,
395     struct sockaddr_encap *pfilter,
396     struct sockaddr_encap *pfiltermask)
397 {
398 	if (ids != NULL)
399 		if (tdb->tdb_ids == NULL ||
400 		    !ipsp_ids_match(tdb->tdb_ids, ids))
401 			return 0;
402 
403 	/* Check for filter matches. */
404 	if (pfilter != NULL && pfiltermask != NULL &&
405 	    tdb->tdb_filter.sen_type) {
406 		/*
407 		 * XXX We should really be doing a subnet-check (see
408 		 * whether the TDB-associated filter is a subset
409 		 * of the policy's. For now, an exact match will solve
410 		 * most problems (all this will do is make every
411 		 * policy get its own SAs).
412 		 */
413 		if (memcmp(&tdb->tdb_filter, pfilter,
414 		    sizeof(struct sockaddr_encap)) ||
415 		    memcmp(&tdb->tdb_filtermask, pfiltermask,
416 		    sizeof(struct sockaddr_encap)))
417 			return 0;
418 	}
419 
420 	return 1;
421 }
422 
423 /*
424  * Get an SA given the remote address, the security protocol type, and
425  * the desired IDs.
426  */
427 struct tdb *
428 gettdbbydst(u_int rdomain, union sockaddr_union *dst, u_int8_t sproto,
429     struct ipsec_ids *ids,
430     struct sockaddr_encap *filter, struct sockaddr_encap *filtermask)
431 {
432 	u_int32_t hashval;
433 	struct tdb *tdbp;
434 
435 	NET_ASSERT_LOCKED();
436 
437 	if (tdbdst == NULL)
438 		return (struct tdb *) NULL;
439 
440 	hashval = tdb_hash(0, dst, sproto);
441 
442 	for (tdbp = tdbdst[hashval]; tdbp != NULL; tdbp = tdbp->tdb_dnext)
443 		if ((tdbp->tdb_sproto == sproto) &&
444 		    (tdbp->tdb_rdomain == rdomain) &&
445 		    ((tdbp->tdb_flags & TDBF_INVALID) == 0) &&
446 		    (!memcmp(&tdbp->tdb_dst, dst, dst->sa.sa_len))) {
447 			/* Do IDs match ? */
448 			if (!ipsp_aux_match(tdbp, ids, filter, filtermask))
449 				continue;
450 			break;
451 		}
452 
453 	return tdbp;
454 }
455 
456 /*
457  * Get an SA given the source address, the security protocol type, and
458  * the desired IDs.
459  */
460 struct tdb *
461 gettdbbysrc(u_int rdomain, union sockaddr_union *src, u_int8_t sproto,
462     struct ipsec_ids *ids,
463     struct sockaddr_encap *filter, struct sockaddr_encap *filtermask)
464 {
465 	u_int32_t hashval;
466 	struct tdb *tdbp;
467 
468 	NET_ASSERT_LOCKED();
469 
470 	if (tdbsrc == NULL)
471 		return (struct tdb *) NULL;
472 
473 	hashval = tdb_hash(0, src, sproto);
474 
475 	for (tdbp = tdbsrc[hashval]; tdbp != NULL; tdbp = tdbp->tdb_snext)
476 		if ((tdbp->tdb_sproto == sproto) &&
477 		    (tdbp->tdb_rdomain == rdomain) &&
478 		    ((tdbp->tdb_flags & TDBF_INVALID) == 0) &&
479 		    (!memcmp(&tdbp->tdb_src, src, src->sa.sa_len))) {
480 			/* Check whether IDs match */
481 			if (!ipsp_aux_match(tdbp, ids, filter,
482 			    filtermask))
483 				continue;
484 			break;
485 		}
486 
487 	return tdbp;
488 }
489 
#ifdef DDB

#define NBUCKETS 16
/*
 * ddb(4) helper: print a histogram of tdbh[] chain lengths.  Chains
 * of NBUCKETS-1 or more entries are lumped into the last bin.
 * Note: use #ifdef, matching the guard around the prototype above;
 * "#if DDB" breaks if the option is defined with an empty value.
 */
void
tdb_hashstats(void)
{
	int i, cnt, buckets[NBUCKETS];
	struct tdb *tdbp;

	if (tdbh == NULL) {
		db_printf("no tdb hash table\n");
		return;
	}

	memset(buckets, 0, sizeof(buckets));
	for (i = 0; i <= tdb_hashmask; i++) {
		/* Count this chain, capped at NBUCKETS-1. */
		cnt = 0;
		for (tdbp = tdbh[i]; cnt < NBUCKETS - 1 && tdbp != NULL;
		    tdbp = tdbp->tdb_hnext)
			cnt++;
		buckets[cnt]++;
	}

	db_printf("tdb cnt\t\tbucket cnt\n");
	for (i = 0; i < NBUCKETS; i++)
		if (buckets[i] > 0)
			db_printf("%d%s\t\t%d\n", i, i == NBUCKETS - 1 ?
			    "+" : "", buckets[i]);
}
#endif	/* DDB */
520 
521 int
522 tdb_walk(u_int rdomain, int (*walker)(struct tdb *, void *, int), void *arg)
523 {
524 	int i, rval = 0;
525 	struct tdb *tdbp, *next;
526 
527 	NET_ASSERT_LOCKED();
528 
529 	if (tdbh == NULL)
530 		return ENOENT;
531 
532 	for (i = 0; i <= tdb_hashmask; i++)
533 		for (tdbp = tdbh[i]; rval == 0 && tdbp != NULL; tdbp = next) {
534 			next = tdbp->tdb_hnext;
535 
536 			if (rdomain != tdbp->tdb_rdomain)
537 				continue;
538 
539 			if (i == tdb_hashmask && next == NULL)
540 				rval = walker(tdbp, (void *)arg, 1);
541 			else
542 				rval = walker(tdbp, (void *)arg, 0);
543 		}
544 
545 	return rval;
546 }
547 
548 void
549 tdb_timeout(void *v)
550 {
551 	struct tdb *tdb = v;
552 
553 	NET_LOCK();
554 	if (tdb->tdb_flags & TDBF_TIMER) {
555 		/* If it's an "invalid" TDB do a silent expiration. */
556 		if (!(tdb->tdb_flags & TDBF_INVALID))
557 			pfkeyv2_expire(tdb, SADB_EXT_LIFETIME_HARD);
558 		tdb_delete(tdb);
559 	}
560 	NET_UNLOCK();
561 }
562 
563 void
564 tdb_firstuse(void *v)
565 {
566 	struct tdb *tdb = v;
567 
568 	NET_LOCK();
569 	if (tdb->tdb_flags & TDBF_SOFT_FIRSTUSE) {
570 		/* If the TDB hasn't been used, don't renew it. */
571 		if (tdb->tdb_first_use != 0)
572 			pfkeyv2_expire(tdb, SADB_EXT_LIFETIME_HARD);
573 		tdb_delete(tdb);
574 	}
575 	NET_UNLOCK();
576 }
577 
578 void
579 tdb_soft_timeout(void *v)
580 {
581 	struct tdb *tdb = v;
582 
583 	NET_LOCK();
584 	if (tdb->tdb_flags & TDBF_SOFT_TIMER) {
585 		/* Soft expirations. */
586 		pfkeyv2_expire(tdb, SADB_EXT_LIFETIME_SOFT);
587 		tdb->tdb_flags &= ~TDBF_SOFT_TIMER;
588 	}
589 	NET_UNLOCK();
590 }
591 
592 void
593 tdb_soft_firstuse(void *v)
594 {
595 	struct tdb *tdb = v;
596 
597 	NET_LOCK();
598 	if (tdb->tdb_flags & TDBF_SOFT_FIRSTUSE) {
599 		/* If the TDB hasn't been used, don't renew it. */
600 		if (tdb->tdb_first_use != 0)
601 			pfkeyv2_expire(tdb, SADB_EXT_LIFETIME_SOFT);
602 		tdb->tdb_flags &= ~TDBF_SOFT_FIRSTUSE;
603 	}
604 	NET_UNLOCK();
605 }
606 
/*
 * Double the size of all three hash tables and redistribute every
 * linked-in TDB.  A fresh SipHash key is drawn, so all bucket indices
 * change.  Called from puttdb() when chains grow too long.
 */
void
tdb_rehash(void)
{
	struct tdb **new_tdbh, **new_tdbdst, **new_srcaddr, *tdbp, *tdbnp;
	u_int i, old_hashmask = tdb_hashmask;
	u_int32_t hashval;

	NET_ASSERT_LOCKED();

	/* Next power-of-two size: mask grows from 2^n-1 to 2^(n+1)-1. */
	tdb_hashmask = (tdb_hashmask << 1) | 1;

	arc4random_buf(&tdbkey, sizeof(tdbkey));
	new_tdbh = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), M_TDB,
	    M_WAITOK | M_ZERO);
	new_tdbdst = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), M_TDB,
	    M_WAITOK | M_ZERO);
	new_srcaddr = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), M_TDB,
	    M_WAITOK | M_ZERO);

	for (i = 0; i <= old_hashmask; i++) {
		/* Re-bucket the SPI/dst/proto chains. */
		for (tdbp = tdbh[i]; tdbp != NULL; tdbp = tdbnp) {
			tdbnp = tdbp->tdb_hnext;
			hashval = tdb_hash(tdbp->tdb_spi, &tdbp->tdb_dst,
			    tdbp->tdb_sproto);
			tdbp->tdb_hnext = new_tdbh[hashval];
			new_tdbh[hashval] = tdbp;
		}

		/* Re-bucket the destination chains. */
		for (tdbp = tdbdst[i]; tdbp != NULL; tdbp = tdbnp) {
			tdbnp = tdbp->tdb_dnext;
			hashval = tdb_hash(0, &tdbp->tdb_dst, tdbp->tdb_sproto);
			tdbp->tdb_dnext = new_tdbdst[hashval];
			new_tdbdst[hashval] = tdbp;
		}

		/* Re-bucket the source chains. */
		for (tdbp = tdbsrc[i]; tdbp != NULL; tdbp = tdbnp) {
			tdbnp = tdbp->tdb_snext;
			hashval = tdb_hash(0, &tdbp->tdb_src, tdbp->tdb_sproto);
			tdbp->tdb_snext = new_srcaddr[hashval];
			new_srcaddr[hashval] = tdbp;
		}
	}

	free(tdbh, M_TDB, 0);
	tdbh = new_tdbh;

	free(tdbdst, M_TDB, 0);
	tdbdst = new_tdbdst;

	free(tdbsrc, M_TDB, 0);
	tdbsrc = new_srcaddr;
}
659 
/*
 * Add TDB in the hash table.  Links the TDB into all three chains
 * (SPI/dst/proto, dst-only, src-only) and grows the tables if needed.
 */
void
puttdb(struct tdb *tdbp)
{
	u_int32_t hashval;

	NET_ASSERT_LOCKED();

	/* Lazily allocate the three hash tables on first insertion. */
	if (tdbh == NULL) {
		arc4random_buf(&tdbkey, sizeof(tdbkey));
		tdbh = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *),
		    M_TDB, M_WAITOK | M_ZERO);
		tdbdst = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *),
		    M_TDB, M_WAITOK | M_ZERO);
		tdbsrc = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *),
		    M_TDB, M_WAITOK | M_ZERO);
	}

	hashval = tdb_hash(tdbp->tdb_spi, &tdbp->tdb_dst, tdbp->tdb_sproto);

	/*
	 * Rehash if this tdb would cause a bucket to have more than
	 * two items and if the number of tdbs exceed 10% of the
	 * bucket count.  This number is arbitrarily chosen and is
	 * just a measure to not keep rehashing when adding and
	 * removing tdbs which happens to always end up in the same
	 * bucket, which is not uncommon when doing manual keying.
	 */
	if (tdbh[hashval] != NULL && tdbh[hashval]->tdb_hnext != NULL &&
	    tdb_count * 10 > tdb_hashmask + 1) {
		tdb_rehash();
		/* Rehash drew a new key; the bucket index changed. */
		hashval = tdb_hash(tdbp->tdb_spi, &tdbp->tdb_dst,
		    tdbp->tdb_sproto);
	}

	tdbp->tdb_hnext = tdbh[hashval];
	tdbh[hashval] = tdbp;

	hashval = tdb_hash(0, &tdbp->tdb_dst, tdbp->tdb_sproto);
	tdbp->tdb_dnext = tdbdst[hashval];
	tdbdst[hashval] = tdbp;

	hashval = tdb_hash(0, &tdbp->tdb_src, tdbp->tdb_sproto);
	tdbp->tdb_snext = tdbsrc[hashval];
	tdbsrc[hashval] = tdbp;

	tdb_count++;
#ifdef IPSEC
	/* Account valid tunnel-mode SAs. */
	if ((tdbp->tdb_flags & (TDBF_INVALID|TDBF_TUNNELING)) == TDBF_TUNNELING)
		ipsecstat_inc(ipsec_tunnels);
#endif /* IPSEC */

	ipsec_last_added = getuptime();
}
716 
717 void
718 tdb_unlink(struct tdb *tdbp)
719 {
720 	struct tdb *tdbpp;
721 	u_int32_t hashval;
722 
723 	NET_ASSERT_LOCKED();
724 
725 	if (tdbh == NULL)
726 		return;
727 
728 	hashval = tdb_hash(tdbp->tdb_spi, &tdbp->tdb_dst, tdbp->tdb_sproto);
729 
730 	if (tdbh[hashval] == tdbp) {
731 		tdbh[hashval] = tdbp->tdb_hnext;
732 	} else {
733 		for (tdbpp = tdbh[hashval]; tdbpp != NULL;
734 		    tdbpp = tdbpp->tdb_hnext) {
735 			if (tdbpp->tdb_hnext == tdbp) {
736 				tdbpp->tdb_hnext = tdbp->tdb_hnext;
737 				break;
738 			}
739 		}
740 	}
741 
742 	tdbp->tdb_hnext = NULL;
743 
744 	hashval = tdb_hash(0, &tdbp->tdb_dst, tdbp->tdb_sproto);
745 
746 	if (tdbdst[hashval] == tdbp) {
747 		tdbdst[hashval] = tdbp->tdb_dnext;
748 	} else {
749 		for (tdbpp = tdbdst[hashval]; tdbpp != NULL;
750 		    tdbpp = tdbpp->tdb_dnext) {
751 			if (tdbpp->tdb_dnext == tdbp) {
752 				tdbpp->tdb_dnext = tdbp->tdb_dnext;
753 				break;
754 			}
755 		}
756 	}
757 
758 	tdbp->tdb_dnext = NULL;
759 
760 	hashval = tdb_hash(0, &tdbp->tdb_src, tdbp->tdb_sproto);
761 
762 	if (tdbsrc[hashval] == tdbp) {
763 		tdbsrc[hashval] = tdbp->tdb_snext;
764 	}
765 	else {
766 		for (tdbpp = tdbsrc[hashval]; tdbpp != NULL;
767 		    tdbpp = tdbpp->tdb_snext) {
768 			if (tdbpp->tdb_snext == tdbp) {
769 				tdbpp->tdb_snext = tdbp->tdb_snext;
770 				break;
771 			}
772 		}
773 	}
774 
775 	tdbp->tdb_snext = NULL;
776 	tdb_count--;
777 #ifdef IPSEC
778 	if ((tdbp->tdb_flags & (TDBF_INVALID|TDBF_TUNNELING)) ==
779 	    TDBF_TUNNELING) {
780 		ipsecstat_dec(ipsec_tunnels);
781 		ipsecstat_inc(ipsec_prevtunnels);
782 	}
783 #endif /* IPSEC */
784 }
785 
/*
 * Unlink a TDB from the hash tables and release it.  The final
 * pool_put() happens later in tdb_reaper(), scheduled by tdb_free().
 */
void
tdb_delete(struct tdb *tdbp)
{
	NET_ASSERT_LOCKED();

	tdb_unlink(tdbp);
	tdb_free(tdbp);
}
794 
/*
 * Allocate a TDB and initialize a few basic fields.
 */
struct tdb *
tdb_alloc(u_int rdomain)
{
	struct tdb *tdbp;
	static int initialized = 0;

	NET_ASSERT_LOCKED();

	/* One-time pool setup; relies on the net lock asserted above. */
	if (!initialized) {
		pool_init(&tdb_pool, sizeof(struct tdb), 0, IPL_SOFTNET, 0,
		    "tdb", NULL);
		initialized = 1;
	}
	/* PR_WAITOK: may sleep; PR_ZERO: all fields start out zeroed. */
	tdbp = pool_get(&tdb_pool, PR_WAITOK | PR_ZERO);

	TAILQ_INIT(&tdbp->tdb_policy_head);

	/* Record establishment time. */
	tdbp->tdb_established = gettime();

	/* Save routing domain */
	tdbp->tdb_rdomain = rdomain;
	tdbp->tdb_rdomain_post = rdomain;

	/* Initialize timeouts. */
	timeout_set_proc(&tdbp->tdb_timer_tmo, tdb_timeout, tdbp);
	timeout_set_proc(&tdbp->tdb_first_tmo, tdb_firstuse, tdbp);
	timeout_set_proc(&tdbp->tdb_stimer_tmo, tdb_soft_timeout, tdbp);
	timeout_set_proc(&tdbp->tdb_sfirst_tmo, tdb_soft_firstuse, tdbp);

	return tdbp;
}
830 
/*
 * Release everything a TDB references and schedule its destruction.
 * The caller must have unlinked the TDB already (see tdb_delete());
 * the memory itself is returned to tdb_pool later, from tdb_reaper().
 */
void
tdb_free(struct tdb *tdbp)
{
	struct ipsec_policy *ipo;

	NET_ASSERT_LOCKED();

	/* Let the transform wipe its key material. */
	if (tdbp->tdb_xform) {
		(*(tdbp->tdb_xform->xf_zeroize))(tdbp);
		tdbp->tdb_xform = NULL;
	}

#if NPFSYNC > 0
	/* Cleanup pfsync references */
	pfsync_delete_tdb(tdbp);
#endif

	/* Cleanup SPD references. */
	for (ipo = TAILQ_FIRST(&tdbp->tdb_policy_head); ipo;
	    ipo = TAILQ_FIRST(&tdbp->tdb_policy_head))	{
		TAILQ_REMOVE(&tdbp->tdb_policy_head, ipo, ipo_tdb_next);
		ipo->ipo_tdb = NULL;
		ipo->ipo_last_searched = 0; /* Force a re-search. */
	}

	/* Drop the reference on the interned identity pair. */
	if (tdbp->tdb_ids) {
		ipsp_ids_free(tdbp->tdb_ids);
		tdbp->tdb_ids = NULL;
	}

#if NPF > 0
	if (tdbp->tdb_tag) {
		pf_tag_unref(tdbp->tdb_tag);
		tdbp->tdb_tag = 0;
	}
#endif

	/* Clear back-pointers of TDBs chained to us via onext/inext. */
	if ((tdbp->tdb_onext) && (tdbp->tdb_onext->tdb_inext == tdbp))
		tdbp->tdb_onext->tdb_inext = NULL;

	if ((tdbp->tdb_inext) && (tdbp->tdb_inext->tdb_onext == tdbp))
		tdbp->tdb_inext->tdb_onext = NULL;

	/* Remove expiration timeouts. */
	tdbp->tdb_flags &= ~(TDBF_FIRSTUSE | TDBF_SOFT_FIRSTUSE | TDBF_TIMER |
	    TDBF_SOFT_TIMER);
	timeout_del(&tdbp->tdb_timer_tmo);
	timeout_del(&tdbp->tdb_first_tmo);
	timeout_del(&tdbp->tdb_stimer_tmo);
	timeout_del(&tdbp->tdb_sfirst_tmo);

	/* Defer the actual pool_put() to tdb_reaper() via a 0-tick timeout. */
	timeout_set_proc(&tdbp->tdb_timer_tmo, tdb_reaper, tdbp);
	timeout_add(&tdbp->tdb_timer_tmo, 0);
}
885 
/*
 * Final destruction of a TDB; runs from the timeout scheduled by
 * tdb_free() and simply returns the memory to the pool.
 */
void
tdb_reaper(void *xtdbp)
{
	struct tdb *tdbp = xtdbp;

	pool_put(&tdb_pool, tdbp);
}
893 
894 /*
895  * Do further initializations of a TDB.
896  */
897 int
898 tdb_init(struct tdb *tdbp, u_int16_t alg, struct ipsecinit *ii)
899 {
900 	struct xformsw *xsp;
901 	int err;
902 #ifdef ENCDEBUG
903 	char buf[INET6_ADDRSTRLEN];
904 #endif
905 
906 	for (xsp = xformsw; xsp < xformswNXFORMSW; xsp++) {
907 		if (xsp->xf_type == alg) {
908 			err = (*(xsp->xf_init))(tdbp, xsp, ii);
909 			return err;
910 		}
911 	}
912 
913 	DPRINTF(("%s: no alg %d for spi %08x, addr %s, proto %d\n", __func__,
914 	    alg, ntohl(tdbp->tdb_spi), ipsp_address(&tdbp->tdb_dst, buf,
915 	    sizeof(buf)), tdbp->tdb_sproto));
916 
917 	return EINVAL;
918 }
919 
#ifdef ENCDEBUG
/*
 * Return a printable string for the address; formats IPv4/IPv6
 * sockaddrs into `buf' for debug output.
 */
const char *
ipsp_address(union sockaddr_union *sa, char *buf, socklen_t size)
{
	const void *addr;
	int af = sa->sa.sa_family;

	if (af == AF_INET)
		addr = &sa->sin.sin_addr;
#ifdef INET6
	else if (af == AF_INET6)
		addr = &sa->sin6.sin6_addr;
#endif /* INET6 */
	else
		return "(unknown address family)";

	return inet_ntop(af, addr, buf, (size_t)size);
}
#endif /* ENCDEBUG */
941 
942 /* Check whether an IP{4,6} address is unspecified. */
943 int
944 ipsp_is_unspecified(union sockaddr_union addr)
945 {
946 	switch (addr.sa.sa_family) {
947 	case AF_INET:
948 		if (addr.sin.sin_addr.s_addr == INADDR_ANY)
949 			return 1;
950 		else
951 			return 0;
952 
953 #ifdef INET6
954 	case AF_INET6:
955 		if (IN6_IS_ADDR_UNSPECIFIED(&addr.sin6.sin6_addr))
956 			return 1;
957 		else
958 			return 0;
959 #endif /* INET6 */
960 
961 	case 0: /* No family set. */
962 	default:
963 		return 1;
964 	}
965 }
966 
/*
 * Identity pairs are interned via ipsp_ids_insert(), so pointer
 * equality implies equality of the identity payloads.
 */
int
ipsp_ids_match(struct ipsec_ids *a, struct ipsec_ids *b)
{
	return a == b;
}
972 
/*
 * Intern an identity pair.  If an equal pair is already in the tree,
 * take a reference on it and return it; otherwise link `ids' into both
 * trees under a freshly allocated flow number with refcount 1.
 * Returns NULL only when the whole flow-number space is in use.
 */
struct ipsec_ids *
ipsp_ids_insert(struct ipsec_ids *ids)
{
	struct ipsec_ids *found;
	u_int32_t start_flow;

	NET_ASSERT_LOCKED();

	found = RBT_INSERT(ipsec_ids_tree, &ipsec_ids_tree, ids);
	if (found) {
		/* if refcount was zero, then timeout is running */
		if (found->id_refcount++ == 0)
			timeout_del(&found->id_timeout);
		DPRINTF(("%s: ids %p count %d\n", __func__,
		    found, found->id_refcount));
		return found;
	}
	/* Allocate the next free flow number, never handing out zero. */
	ids->id_flow = start_flow = ipsec_ids_next_flow;
	if (++ipsec_ids_next_flow == 0)
		ipsec_ids_next_flow = 1;
	while (RBT_INSERT(ipsec_ids_flows, &ipsec_ids_flows, ids) != NULL) {
		ids->id_flow = ipsec_ids_next_flow;
		if (++ipsec_ids_next_flow == 0)
			ipsec_ids_next_flow = 1;
		/* Wrapped around to where we started: space exhausted. */
		if (ipsec_ids_next_flow == start_flow) {
			DPRINTF(("ipsec_ids_next_flow exhausted %u\n",
			    ipsec_ids_next_flow));
			/*
			 * NOTE(review): on this path `ids' remains linked
			 * in ipsec_ids_tree — confirm callers handle that.
			 */
			return NULL;
		}
	}
	ids->id_refcount = 1;
	DPRINTF(("%s: new ids %p flow %u\n", __func__, ids, ids->id_flow));
	timeout_set_proc(&ids->id_timeout, ipsp_ids_timeout, ids);
	return ids;
}
1008 
1009 struct ipsec_ids *
1010 ipsp_ids_lookup(u_int32_t ipsecflowinfo)
1011 {
1012 	struct ipsec_ids	key;
1013 
1014 	NET_ASSERT_LOCKED();
1015 
1016 	key.id_flow = ipsecflowinfo;
1017 	return RBT_FIND(ipsec_ids_flows, &ipsec_ids_flows, &key);
1018 }
1019 
/* free ids only from delayed timeout */
void
ipsp_ids_timeout(void *arg)
{
	struct ipsec_ids *ids = arg;

	DPRINTF(("%s: ids %p count %d\n", __func__, ids, ids->id_refcount));
	KASSERT(ids->id_refcount == 0);

	NET_LOCK();
	/* Drop from both trees before freeing the identity payloads. */
	RBT_REMOVE(ipsec_ids_tree, &ipsec_ids_tree, ids);
	RBT_REMOVE(ipsec_ids_flows, &ipsec_ids_flows, ids);
	free(ids->id_local, M_CREDENTIALS, 0);
	free(ids->id_remote, M_CREDENTIALS, 0);
	free(ids, M_CREDENTIALS, 0);
	NET_UNLOCK();
}
1037 
/* decrements refcount, actual free happens in timeout */
void
ipsp_ids_free(struct ipsec_ids *ids)
{
	/*
	 * If the refcount becomes zero, then a timeout is started. This
	 * timeout must be cancelled if refcount is increased from zero.
	 * (ipsp_ids_insert() does exactly that.)
	 */
	DPRINTF(("%s: ids %p count %d\n", __func__, ids, ids->id_refcount));
	KASSERT(ids->id_refcount > 0);
	/* Keep unreferenced ids around for ipsec_ids_idle seconds. */
	if (--ids->id_refcount == 0)
		timeout_add_sec(&ids->id_timeout, ipsec_ids_idle);
}
1051 
1052 static int
1053 ipsp_id_cmp(struct ipsec_id *a, struct ipsec_id *b)
1054 {
1055 	if (a->type > b->type)
1056 		return 1;
1057 	if (a->type < b->type)
1058 		return -1;
1059 	if (a->len > b->len)
1060 		return 1;
1061 	if (a->len < b->len)
1062 		return -1;
1063 	return memcmp(a + 1, b + 1, a->len);
1064 }
1065 
1066 static inline int
1067 ipsp_ids_cmp(const struct ipsec_ids *a, const struct ipsec_ids *b)
1068 {
1069 	int ret;
1070 
1071 	ret = ipsp_id_cmp(a->id_remote, b->id_remote);
1072 	if (ret != 0)
1073 		return ret;
1074 	return ipsp_id_cmp(a->id_local, b->id_local);
1075 }
1076 
1077 static inline int
1078 ipsp_ids_flow_cmp(const struct ipsec_ids *a, const struct ipsec_ids *b)
1079 {
1080 	if (a->id_flow > b->id_flow)
1081 		return 1;
1082 	if (a->id_flow < b->id_flow)
1083 		return -1;
1084 	return 0;
1085 }
1086