1 /* $NetBSD: rpst.c,v 1.12 2021/11/14 20:51:57 andvar Exp $ */
2
3 /*-
4 * Copyright (c)2009 YAMAMOTO Takashi,
5 * All rights reserved.
6 *
7 * Redistribution and use in source and binary forms, with or without
8 * modification, are permitted provided that the following conditions
9 * are met:
10 * 1. Redistributions of source code must retain the above copyright
11 * notice, this list of conditions and the following disclaimer.
12 * 2. Redistributions in binary form must reproduce the above copyright
13 * notice, this list of conditions and the following disclaimer in the
14 * documentation and/or other materials provided with the distribution.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
17 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
20 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
22 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
23 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
24 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
25 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
26 * SUCH DAMAGE.
27 */
28
/*
 * radix priority search tree
 *
 * described in:
 *	SIAM J. COMPUT.
 *	Vol. 14, No. 2, May 1985
 *	PRIORITY SEARCH TREES
 *	EDWARD M. McCREIGHT
 *
 * ideas from linux:
 *	- grow tree height on-demand.
 *	- allow duplicated X values.  in that case, we act as a heap.
 */
42
43 #include <sys/cdefs.h>
44
45 #if defined(_KERNEL) || defined(_STANDALONE)
46 __KERNEL_RCSID(0, "$NetBSD: rpst.c,v 1.12 2021/11/14 20:51:57 andvar Exp $");
47 #include <sys/param.h>
48 #include <lib/libkern/libkern.h>
49 #if defined(_STANDALONE)
50 #include <lib/libsa/stand.h>
51 #endif /* defined(_STANDALONE) */
52 #else /* defined(_KERNEL) || defined(_STANDALONE) */
53 __RCSID("$NetBSD: rpst.c,v 1.12 2021/11/14 20:51:57 andvar Exp $");
54 #include <assert.h>
55 #include <stdbool.h>
56 #include <string.h>
57 #if 1
58 #define KASSERT assert
59 #else
60 #define KASSERT(a)
61 #endif
62 #endif /* defined(_KERNEL) || defined(_STANDALONE) */
63
64 #include <sys/rpst.h>
65
66 /*
67 * rpst_init_tree: initialize a tree.
68 */
69
70 void
rpst_init_tree(struct rpst_tree * t)71 rpst_init_tree(struct rpst_tree *t)
72 {
73
74 t->t_root = NULL;
75 t->t_height = 0;
76 }
77
/*
 * rpst_height2max: return the largest index a tree of the given height
 * can hold, ie. the value whose low (height + 1) bits are all set.
 *
 *	 0 ... 0x0000000000000001
 *	 1 ... 0x0000000000000003
 *	 2 ... 0x0000000000000007
 *	 3 ... 0x000000000000000f
 *
 *	31 ... 0x00000000ffffffff
 *
 *	63 ... 0xffffffffffffffff
 *
 * => height must be less than 64.
 */

static uint64_t
rpst_height2max(unsigned int height)
{

	KASSERT(height < 64);
	/*
	 * discard the top (63 - height) bits of an all-ones word.  the
	 * shift count stays within 0..63, so height 63 needs no special
	 * case.
	 */
	return UINT64_MAX >> (63 - height);
}
102
103 /*
104 * rpst_level2mask: calculate the mask for the given level in the tree.
105 *
106 * the mask used to index root's children is level 0.
107 */
108
109 static uint64_t
rpst_level2mask(const struct rpst_tree * t,unsigned int level)110 rpst_level2mask(const struct rpst_tree *t, unsigned int level)
111 {
112 uint64_t mask;
113
114 if (t->t_height < level) {
115 mask = 0;
116 } else {
117 mask = UINT64_C(1) << (t->t_height - level);
118 }
119 return mask;
120 }
121
122 /*
123 * rpst_startmask: calculate the mask for the start of a search.
124 * (ie. the mask for the top-most bit)
125 */
126
127 static uint64_t
rpst_startmask(const struct rpst_tree * t)128 rpst_startmask(const struct rpst_tree *t)
129 {
130 const uint64_t mask = rpst_level2mask(t, 0);
131
132 KASSERT((mask | (mask - 1)) == rpst_height2max(t->t_height));
133 return mask;
134 }
135
136 /*
137 * rpst_update_parents: update n_parent of children
138 */
139
140 static inline void
rpst_update_parents(struct rpst_node * n)141 rpst_update_parents(struct rpst_node *n)
142 {
143 int i;
144
145 for (i = 0; i < 2; i++) {
146 if (n->n_children[i] != NULL) {
147 n->n_children[i]->n_parent = n;
148 }
149 }
150 }
151
152 /*
153 * rpst_enlarge_tree: enlarge tree so that 'idx' can be stored
154 */
155
156 static void
rpst_enlarge_tree(struct rpst_tree * t,uint64_t idx)157 rpst_enlarge_tree(struct rpst_tree *t, uint64_t idx)
158 {
159
160 while (idx > rpst_height2max(t->t_height)) {
161 struct rpst_node *n = t->t_root;
162
163 if (n != NULL) {
164 rpst_remove_node(t, n);
165 memset(&n->n_children, 0, sizeof(n->n_children));
166 n->n_children[0] = t->t_root;
167 t->t_root->n_parent = n;
168 t->t_root = n;
169 n->n_parent = NULL;
170 }
171 t->t_height++;
172 }
173 }
174
175 /*
176 * rpst_insert_node1: a helper for rpst_insert_node.
177 */
178
179 static struct rpst_node *
rpst_insert_node1(struct rpst_node ** where,struct rpst_node * n,uint64_t mask)180 rpst_insert_node1(struct rpst_node **where, struct rpst_node *n, uint64_t mask)
181 {
182 struct rpst_node *parent;
183 struct rpst_node *cur;
184 unsigned int idx;
185
186 KASSERT((n->n_x & ((-mask) << 1)) == 0);
187 parent = NULL;
188 next:
189 cur = *where;
190 if (cur == NULL) {
191 n->n_parent = parent;
192 memset(&n->n_children, 0, sizeof(n->n_children));
193 *where = n;
194 return NULL;
195 }
196 KASSERT(cur->n_parent == parent);
197 if (n->n_y == cur->n_y && n->n_x == cur->n_x) {
198 return cur;
199 }
200 if (n->n_y < cur->n_y) {
201 /*
202 * swap cur and n.
203 * note that n is not in tree.
204 */
205 memcpy(n->n_children, cur->n_children, sizeof(n->n_children));
206 n->n_parent = cur->n_parent;
207 rpst_update_parents(n);
208 *where = n;
209 n = cur;
210 cur = *where;
211 }
212 KASSERT(*where == cur);
213 idx = (n->n_x & mask) != 0;
214 where = &cur->n_children[idx];
215 parent = cur;
216 KASSERT((*where) == NULL || ((((*where)->n_x & mask) != 0) == idx));
217 KASSERT((*where) == NULL || (*where)->n_y >= cur->n_y);
218 mask >>= 1;
219 goto next;
220 }
221
222 /*
223 * rpst_insert_node: insert a node into the tree.
224 *
225 * => return NULL on success.
226 * => if a duplicated node (a node with the same X,Y pair as ours) is found,
227 * return the node. in that case, the tree is intact.
228 */
229
230 struct rpst_node *
rpst_insert_node(struct rpst_tree * t,struct rpst_node * n)231 rpst_insert_node(struct rpst_tree *t, struct rpst_node *n)
232 {
233
234 rpst_enlarge_tree(t, n->n_x);
235 return rpst_insert_node1(&t->t_root, n, rpst_startmask(t));
236 }
237
238 /*
239 * rpst_find_pptr: find a pointer to the given node.
240 *
241 * also, return the parent node via parentp. (NULL for the root node.)
242 */
243
244 static inline struct rpst_node **
rpst_find_pptr(struct rpst_tree * t,struct rpst_node * n,struct rpst_node ** parentp)245 rpst_find_pptr(struct rpst_tree *t, struct rpst_node *n,
246 struct rpst_node **parentp)
247 {
248 struct rpst_node * const parent = n->n_parent;
249 unsigned int i;
250
251 *parentp = parent;
252 if (parent == NULL) {
253 return &t->t_root;
254 }
255 for (i = 0; i < 2 - 1; i++) {
256 if (parent->n_children[i] == n) {
257 break;
258 }
259 }
260 KASSERT(parent->n_children[i] == n);
261 return &parent->n_children[i];
262 }
263
264 /*
265 * rpst_remove_node_at: remove a node at *where.
266 */
267
268 static void
rpst_remove_node_at(struct rpst_node * parent,struct rpst_node ** where,struct rpst_node * cur)269 rpst_remove_node_at(struct rpst_node *parent, struct rpst_node **where,
270 struct rpst_node *cur)
271 {
272 struct rpst_node *tmp[2];
273 struct rpst_node *selected;
274 unsigned int selected_idx = 0; /* XXX gcc */
275 unsigned int i;
276
277 KASSERT(cur != NULL);
278 KASSERT(parent == cur->n_parent);
279 next:
280 selected = NULL;
281 for (i = 0; i < 2; i++) {
282 struct rpst_node *c;
283
284 c = cur->n_children[i];
285 KASSERT(c == NULL || c->n_parent == cur);
286 if (selected == NULL || (c != NULL && c->n_y < selected->n_y)) {
287 selected = c;
288 selected_idx = i;
289 }
290 }
291 /*
292 * now we have:
293 *
294 * parent
295 * \ <- where
296 * cur
297 * / \
298 * A selected
299 * / \
300 * B C
301 */
302 *where = selected;
303 if (selected == NULL) {
304 return;
305 }
306 /*
307 * swap selected->n_children and cur->n_children.
308 */
309 memcpy(tmp, selected->n_children, sizeof(tmp));
310 memcpy(selected->n_children, cur->n_children, sizeof(tmp));
311 memcpy(cur->n_children, tmp, sizeof(tmp));
312 rpst_update_parents(cur);
313 rpst_update_parents(selected);
314 selected->n_parent = parent;
315 /*
316 * parent
317 * \ <- where
318 * selected
319 * / \
320 * A selected
321 *
322 * cur
323 * / \
324 * B C
325 */
326 where = &selected->n_children[selected_idx];
327 /*
328 * parent
329 * \
330 * selected
331 * / \ <- where
332 * A selected (*)
333 *
334 * cur (**)
335 * / \
336 * B C
337 *
338 * (*) this 'selected' will be overwritten in the next iteration.
339 * (**) cur->n_parent is bogus.
340 */
341 parent = selected;
342 goto next;
343 }
344
/*
 * rpst_remove_node: remove a node from the tree.
 *
 * => the node is located through its n_parent.
 */

void
rpst_remove_node(struct rpst_tree *t, struct rpst_node *n)
{
	struct rpst_node *parent;
	struct rpst_node ** const slot = rpst_find_pptr(t, n, &parent);

	rpst_remove_node_at(parent, slot, n);
}
358
359 static bool __unused
rpst_iterator_match_p(const struct rpst_node * n,const struct rpst_iterator * it)360 rpst_iterator_match_p(const struct rpst_node *n, const struct rpst_iterator *it)
361 {
362
363 if (n->n_y > it->it_max_y) {
364 return false;
365 }
366 if (n->n_x < it->it_min_x) {
367 return false;
368 }
369 if (n->n_x > it->it_max_x) {
370 return false;
371 }
372 return true;
373 }
374
375 struct rpst_node *
rpst_iterate_first(struct rpst_tree * t,uint64_t max_y,uint64_t min_x,uint64_t max_x,struct rpst_iterator * it)376 rpst_iterate_first(struct rpst_tree *t, uint64_t max_y, uint64_t min_x,
377 uint64_t max_x, struct rpst_iterator *it)
378 {
379 struct rpst_node *n;
380
381 KASSERT(min_x <= max_x);
382 n = t->t_root;
383 if (n == NULL || n->n_y > max_y) {
384 return NULL;
385 }
386 if (rpst_height2max(t->t_height) < min_x) {
387 return NULL;
388 }
389 it->it_tree = t;
390 it->it_cur = n;
391 it->it_idx = (min_x & rpst_startmask(t)) != 0;
392 it->it_level = 0;
393 it->it_max_y = max_y;
394 it->it_min_x = min_x;
395 it->it_max_x = max_x;
396 return rpst_iterate_next(it);
397 }
398
399 static inline unsigned int
rpst_node_on_edge_p(const struct rpst_node * n,uint64_t val,uint64_t mask)400 rpst_node_on_edge_p(const struct rpst_node *n, uint64_t val, uint64_t mask)
401 {
402
403 return ((n->n_x ^ val) & ((-mask) << 1)) == 0;
404 }
405
/*
 * rpst_maxidx: return the highest child index of 'n' worth visiting for
 * the upper bound max_x.
 *
 * => if n is on the max_x edge, the mask bit of max_x decides.
 *    otherwise, child 1 is allowed.
 */

static inline uint64_t
rpst_maxidx(const struct rpst_node *n, uint64_t max_x, uint64_t mask)
{

	return rpst_node_on_edge_p(n, max_x, mask) ? ((max_x & mask) != 0) : 1;
}
416
/*
 * rpst_minidx: return the lowest child index of 'n' worth visiting for
 * the lower bound min_x.
 *
 * => if n is on the min_x edge, the mask bit of min_x decides.
 *    otherwise, the visit starts from child 0.
 */

static inline uint64_t
rpst_minidx(const struct rpst_node *n, uint64_t min_x, uint64_t mask)
{

	return rpst_node_on_edge_p(n, min_x, mask) ? ((min_x & mask) != 0) : 0;
}
427
428 struct rpst_node *
rpst_iterate_next(struct rpst_iterator * it)429 rpst_iterate_next(struct rpst_iterator *it)
430 {
431 struct rpst_tree *t;
432 struct rpst_node *n;
433 struct rpst_node *next;
434 const uint64_t max_y = it->it_max_y;
435 const uint64_t min_x = it->it_min_x;
436 const uint64_t max_x = it->it_max_x;
437 unsigned int idx;
438 unsigned int maxidx;
439 unsigned int level;
440 uint64_t mask;
441
442 t = it->it_tree;
443 n = it->it_cur;
444 idx = it->it_idx;
445 level = it->it_level;
446 mask = rpst_level2mask(t, level);
447 maxidx = rpst_maxidx(n, max_x, mask);
448 KASSERT(n == t->t_root || rpst_iterator_match_p(n, it));
449 next:
450 KASSERT(mask == rpst_level2mask(t, level));
451 KASSERT(idx >= rpst_minidx(n, min_x, mask));
452 KASSERT(maxidx == rpst_maxidx(n, max_x, mask));
453 KASSERT(idx <= maxidx + 2);
454 KASSERT(n != NULL);
455 #if 0
456 printf("%s: cur=%p, idx=%u maxidx=%u level=%u mask=%" PRIx64 "\n",
457 __func__, (void *)n, idx, maxidx, level, mask);
458 #endif
459 if (idx == maxidx + 1) { /* visit the current node */
460 idx++;
461 if (min_x <= n->n_x && n->n_x <= max_x) {
462 it->it_cur = n;
463 it->it_idx = idx;
464 it->it_level = level;
465 KASSERT(rpst_iterator_match_p(n, it));
466 return n; /* report */
467 }
468 goto next;
469 } else if (idx == maxidx + 2) { /* back to the parent */
470 struct rpst_node **where;
471
472 where = rpst_find_pptr(t, n, &next);
473 if (next == NULL) {
474 KASSERT(level == 0);
475 KASSERT(t->t_root == n);
476 KASSERT(&t->t_root == where);
477 return NULL; /* done */
478 }
479 KASSERT(level > 0);
480 level--;
481 n = next;
482 mask = rpst_level2mask(t, level);
483 maxidx = rpst_maxidx(n, max_x, mask);
484 idx = where - n->n_children + 1;
485 KASSERT(idx < 2 + 1);
486 goto next;
487 }
488 /* go to a child */
489 KASSERT(idx < 2);
490 next = n->n_children[idx];
491 if (next == NULL || next->n_y > max_y) {
492 idx++;
493 goto next;
494 }
495 KASSERT(next->n_parent == n);
496 KASSERT(next->n_y >= n->n_y);
497 level++;
498 mask >>= 1;
499 n = next;
500 idx = rpst_minidx(n, min_x, mask);
501 maxidx = rpst_maxidx(n, max_x, mask);
502 #if 0
503 printf("%s: visit %p idx=%u level=%u mask=%llx\n",
504 __func__, n, idx, level, mask);
505 #endif
506 goto next;
507 }
508
509 #if defined(UNITTEST)
510 #include <sys/time.h>
511
512 #include <inttypes.h>
513 #include <stdio.h>
514 #include <stdlib.h>
515
516 static void
rpst_dump_node(const struct rpst_node * n,unsigned int depth)517 rpst_dump_node(const struct rpst_node *n, unsigned int depth)
518 {
519 unsigned int i;
520
521 for (i = 0; i < depth; i++) {
522 printf(" ");
523 }
524 printf("[%u]", depth);
525 if (n == NULL) {
526 printf("NULL\n");
527 return;
528 }
529 printf("%p x=%" PRIx64 "(%" PRIu64 ") y=%" PRIx64 "(%" PRIu64 ")\n",
530 (const void *)n, n->n_x, n->n_x, n->n_y, n->n_y);
531 for (i = 0; i < 2; i++) {
532 rpst_dump_node(n->n_children[i], depth + 1);
533 }
534 }
535
536 static void
rpst_dump_tree(const struct rpst_tree * t)537 rpst_dump_tree(const struct rpst_tree *t)
538 {
539
540 printf("pst %p height=%u\n", (const void *)t, t->t_height);
541 rpst_dump_node(t->t_root, 0);
542 }
543
544 struct testnode {
545 struct rpst_node n;
546 struct testnode *next;
547 bool failed;
548 bool found;
549 };
550
551 struct rpst_tree t;
552 struct testnode *h = NULL;
553
/*
 * tvdiff: return tv1 - tv2 in microseconds.
 *
 * => tv1 is expected not to be earlier than tv2.
 * => both timestamps are converted to uintmax_t before being scaled, so
 *    that neither multiplication is carried out in time_t.
 */
static uintmax_t
tvdiff(const struct timeval *tv1, const struct timeval *tv2)
{
	const uintmax_t later =
	    (uintmax_t)tv1->tv_sec * 1000000 + tv1->tv_usec;
	const uintmax_t earlier =
	    (uintmax_t)tv2->tv_sec * 1000000 + tv2->tv_usec;

	return later - earlier;
}
561
562 static unsigned int
query(uint64_t max_y,uint64_t min_x,uint64_t max_x)563 query(uint64_t max_y, uint64_t min_x, uint64_t max_x)
564 {
565 struct testnode *n;
566 struct rpst_node *rn;
567 struct rpst_iterator it;
568 struct timeval start;
569 struct timeval end;
570 unsigned int done;
571
572 printf("querying max_y=%" PRIu64 " min_x=%" PRIu64 " max_x=%" PRIu64
573 "\n",
574 max_y, min_x, max_x);
575 done = 0;
576 gettimeofday(&start, NULL);
577 for (rn = rpst_iterate_first(&t, max_y, min_x, max_x, &it);
578 rn != NULL;
579 rn = rpst_iterate_next(&it)) {
580 done++;
581 #if 0
582 printf("found %p x=%" PRIu64 " y=%" PRIu64 "\n",
583 (void *)rn, rn->n_x, rn->n_y);
584 #endif
585 n = (void *)rn;
586 assert(!n->found);
587 n->found = true;
588 }
589 gettimeofday(&end, NULL);
590 printf("%u nodes found in %ju usecs\n", done,
591 tvdiff(&end, &start));
592
593 gettimeofday(&start, NULL);
594 for (n = h; n != NULL; n = n->next) {
595 assert(n->failed ||
596 n->found == rpst_iterator_match_p(&n->n, &it));
597 n->found = false;
598 }
599 gettimeofday(&end, NULL);
600 printf("(linear search took %ju usecs)\n", tvdiff(&end, &start));
601 return done;
602 }
603
604 int
main(int argc,char * argv[])605 main(int argc, char *argv[])
606 {
607 struct testnode *n;
608 unsigned int i;
609 struct rpst_iterator it;
610 struct timeval start;
611 struct timeval end;
612 uint64_t min_y = UINT64_MAX;
613 uint64_t max_y = 0;
614 uint64_t min_x = UINT64_MAX;
615 uint64_t max_x = 0;
616 uint64_t w;
617 unsigned int done;
618 unsigned int fail;
619 unsigned int num = 500000;
620
621 rpst_init_tree(&t);
622 rpst_dump_tree(&t);
623 assert(NULL == rpst_iterate_first(&t, UINT64_MAX, 0, UINT64_MAX, &it));
624
625 for (i = 0; i < num; i++) {
626 n = malloc(sizeof(*n));
627 if (i > 499000) {
628 n->n.n_x = 10;
629 n->n.n_y = random();
630 } else if (i > 400000) {
631 n->n.n_x = i;
632 n->n.n_y = random();
633 } else {
634 n->n.n_x = random();
635 n->n.n_y = random();
636 }
637 if (n->n.n_y < min_y) {
638 min_y = n->n.n_y;
639 }
640 if (n->n.n_y > max_y) {
641 max_y = n->n.n_y;
642 }
643 if (n->n.n_x < min_x) {
644 min_x = n->n.n_x;
645 }
646 if (n->n.n_x > max_x) {
647 max_x = n->n.n_x;
648 }
649 n->found = false;
650 n->failed = false;
651 n->next = h;
652 h = n;
653 }
654
655 done = 0;
656 fail = 0;
657 gettimeofday(&start, NULL);
658 for (n = h; n != NULL; n = n->next) {
659 struct rpst_node *o;
660 #if 0
661 printf("insert %p x=%" PRIu64 " y=%" PRIu64 "\n",
662 n, n->n.n_x, n->n.n_y);
663 #endif
664 o = rpst_insert_node(&t, &n->n);
665 if (o == NULL) {
666 done++;
667 } else {
668 n->failed = true;
669 fail++;
670 }
671 }
672 gettimeofday(&end, NULL);
673 printf("%u nodes inserted and %u insertion failed in %ju usecs\n",
674 done, fail,
675 tvdiff(&end, &start));
676
677 assert(min_y == 0 || 0 == query(min_y - 1, 0, UINT64_MAX));
678 assert(max_x == UINT64_MAX ||
679 0 == query(UINT64_MAX, max_x + 1, UINT64_MAX));
680 assert(min_x == 0 || 0 == query(UINT64_MAX, 0, min_x - 1));
681
682 done = query(max_y, min_x, max_x);
683 assert(done == num - fail);
684
685 done = query(UINT64_MAX, 0, UINT64_MAX);
686 assert(done == num - fail);
687
688 w = max_x - min_x;
689 query(max_y / 2, min_x, max_x);
690 query(max_y, min_x + w / 2, max_x);
691 query(max_y / 2, min_x + w / 2, max_x);
692 query(max_y / 2, min_x, max_x - w / 2);
693 query(max_y / 2, min_x + w / 3, max_x - w / 3);
694 query(max_y - 1, min_x + 1, max_x - 1);
695 query(UINT64_MAX, 10, 10);
696
697 done = 0;
698 gettimeofday(&start, NULL);
699 for (n = h; n != NULL; n = n->next) {
700 if (n->failed) {
701 continue;
702 }
703 #if 0
704 printf("remove %p x=%" PRIu64 " y=%" PRIu64 "\n",
705 n, n->n.n_x, n->n.n_y);
706 #endif
707 rpst_remove_node(&t, &n->n);
708 done++;
709 }
710 gettimeofday(&end, NULL);
711 printf("%u nodes removed in %ju usecs\n", done,
712 tvdiff(&end, &start));
713
714 rpst_dump_tree(&t);
715 }
716 #endif /* defined(UNITTEST) */
717