xref: /spdk/test/common/lib/ut_multithread.c (revision 8a0a98d35e21f282088edf28b9e8da66ec390e3a)
1 /*-
2  *   BSD LICENSE
3  *
4  *   Copyright (c) Intel Corporation.
5  *   All rights reserved.
6  *
7  *   Redistribution and use in source and binary forms, with or without
8  *   modification, are permitted provided that the following conditions
9  *   are met:
10  *
11  *     * Redistributions of source code must retain the above copyright
12  *       notice, this list of conditions and the following disclaimer.
13  *     * Redistributions in binary form must reproduce the above copyright
14  *       notice, this list of conditions and the following disclaimer in
15  *       the documentation and/or other materials provided with the
16  *       distribution.
17  *     * Neither the name of Intel Corporation nor the names of its
18  *       contributors may be used to endorse or promote products derived
19  *       from this software without specific prior written permission.
20  *
21  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22  *   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23  *   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24  *   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25  *   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27  *   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28  *   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29  *   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30  *   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31  *   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32  */
33 
34 #include "spdk_cunit.h"
35 #include "spdk/thread.h"
36 #include "spdk_internal/mock.h"
37 
/* Number of entries in g_ut_threads; set by allocate_threads(), cleared by free_threads(). */
static uint32_t g_ut_num_threads;

/* Simulated clock in microseconds; advanced by increment_time(), zeroed by reset_time(). */
static uint64_t g_current_time_in_us;
40 
/* Simulated thread lifecycle. */
int allocate_threads(int num_threads);
void free_threads(void);

/* Message and poller dispatch. */
int poll_thread(uintptr_t thread_id);
void poll_threads(void);

/* Simulated clock control. */
void increment_time(uint64_t time_in_us);
void reset_time(void);
47 
48 struct ut_msg {
49 	spdk_thread_fn		fn;
50 	void			*ctx;
51 	TAILQ_ENTRY(ut_msg)	link;
52 };
53 
54 struct ut_thread {
55 	struct spdk_thread	*thread;
56 	struct spdk_io_channel	*ch;
57 	TAILQ_HEAD(, ut_msg)	msgs;
58 	TAILQ_HEAD(, ut_poller)	pollers;
59 };
60 
/* Array of g_ut_num_threads entries; allocated by allocate_threads(), released by free_threads(). */
struct ut_thread *g_ut_threads;
62 
63 struct ut_poller {
64 	spdk_poller_fn		fn;
65 	void			*arg;
66 	TAILQ_ENTRY(ut_poller)	tailq;
67 	uint64_t		period_us;
68 	uint64_t		next_expiration_in_us;
69 };
70 
71 static void
72 __send_msg(spdk_thread_fn fn, void *ctx, void *thread_ctx)
73 {
74 	struct ut_thread *thread = thread_ctx;
75 	struct ut_msg *msg;
76 
77 	msg = calloc(1, sizeof(*msg));
78 	SPDK_CU_ASSERT_FATAL(msg != NULL);
79 
80 	msg->fn = fn;
81 	msg->ctx = ctx;
82 	TAILQ_INSERT_TAIL(&thread->msgs, msg, link);
83 }
84 
85 static struct spdk_poller *
86 __start_poller(void *thread_ctx, spdk_poller_fn fn, void *arg, uint64_t period_microseconds)
87 {
88 	struct ut_thread *thread = thread_ctx;
89 	struct ut_poller *poller = calloc(1, sizeof(struct ut_poller));
90 
91 	SPDK_CU_ASSERT_FATAL(poller != NULL);
92 
93 	poller->fn = fn;
94 	poller->arg = arg;
95 	poller->period_us = period_microseconds;
96 	poller->next_expiration_in_us = g_current_time_in_us + poller->period_us;
97 
98 	TAILQ_INSERT_TAIL(&thread->pollers, poller, tailq);
99 
100 	return (struct spdk_poller *)poller;
101 }
102 
103 static void
104 __stop_poller(struct spdk_poller *poller, void *thread_ctx)
105 {
106 	struct ut_thread *thread = thread_ctx;
107 
108 	TAILQ_REMOVE(&thread->pollers, (struct ut_poller *)poller, tailq);
109 
110 	free(poller);
111 }
112 
113 static uintptr_t g_thread_id = MOCK_PASS_THRU;
114 
115 static void
116 set_thread(uintptr_t thread_id)
117 {
118 	g_thread_id = thread_id;
119 	MOCK_SET(pthread_self, pthread_t, (pthread_t)thread_id);
120 }
121 
122 int
123 allocate_threads(int num_threads)
124 {
125 	struct spdk_thread *thread;
126 	uint32_t i;
127 
128 	g_ut_num_threads = num_threads;
129 
130 	g_ut_threads = calloc(num_threads, sizeof(*g_ut_threads));
131 	SPDK_CU_ASSERT_FATAL(g_ut_threads != NULL);
132 
133 	for (i = 0; i < g_ut_num_threads; i++) {
134 		set_thread(i);
135 		spdk_allocate_thread(__send_msg, __start_poller, __stop_poller,
136 				     &g_ut_threads[i], NULL);
137 		thread = spdk_get_thread();
138 		SPDK_CU_ASSERT_FATAL(thread != NULL);
139 		g_ut_threads[i].thread = thread;
140 		TAILQ_INIT(&g_ut_threads[i].msgs);
141 		TAILQ_INIT(&g_ut_threads[i].pollers);
142 	}
143 
144 	set_thread(MOCK_PASS_THRU);
145 	return 0;
146 }
147 
148 void
149 free_threads(void)
150 {
151 	uint32_t i;
152 
153 	for (i = 0; i < g_ut_num_threads; i++) {
154 		set_thread(i);
155 		spdk_free_thread();
156 	}
157 
158 	g_ut_num_threads = 0;
159 	free(g_ut_threads);
160 	g_ut_threads = NULL;
161 }
162 
163 void
164 increment_time(uint64_t time_in_us)
165 {
166 	g_current_time_in_us += time_in_us;
167 }
168 
169 static void
170 reset_pollers(void)
171 {
172 	uint32_t		i = 0;
173 	struct ut_thread	*thread = NULL;
174 	struct ut_poller	*poller = NULL;
175 	uintptr_t		original_thread_id = g_thread_id;
176 
177 	CU_ASSERT(g_current_time_in_us == 0);
178 
179 	for (i = 0; i < g_ut_num_threads; i++) {
180 		set_thread(i);
181 		thread = &g_ut_threads[i];
182 
183 		TAILQ_FOREACH(poller, &thread->pollers, tailq) {
184 			poller->next_expiration_in_us = g_current_time_in_us + poller->period_us;
185 		}
186 	}
187 
188 	set_thread(original_thread_id);
189 }
190 
191 void
192 reset_time(void)
193 {
194 	g_current_time_in_us = 0;
195 	reset_pollers();
196 }
197 
198 int
199 poll_thread(uintptr_t thread_id)
200 {
201 	int count = 0;
202 	struct ut_thread *thread = &g_ut_threads[thread_id];
203 	struct ut_msg *msg;
204 	struct ut_poller *poller;
205 	uintptr_t original_thread_id;
206 	TAILQ_HEAD(, ut_poller)	tmp_pollers;
207 
208 	CU_ASSERT(thread_id != (uintptr_t)MOCK_PASS_THRU);
209 	CU_ASSERT(thread_id < g_ut_num_threads);
210 
211 	original_thread_id = g_thread_id;
212 	set_thread(thread_id);
213 
214 	while (!TAILQ_EMPTY(&thread->msgs)) {
215 		msg = TAILQ_FIRST(&thread->msgs);
216 		TAILQ_REMOVE(&thread->msgs, msg, link);
217 
218 		msg->fn(msg->ctx);
219 		count++;
220 		free(msg);
221 	}
222 
223 	TAILQ_INIT(&tmp_pollers);
224 
225 	while (!TAILQ_EMPTY(&thread->pollers)) {
226 		poller = TAILQ_FIRST(&thread->pollers);
227 		TAILQ_REMOVE(&thread->pollers, poller, tailq);
228 
229 		while (g_current_time_in_us >= poller->next_expiration_in_us) {
230 			if (poller->fn) {
231 				poller->fn(poller->arg);
232 			}
233 
234 			if (poller->period_us == 0) {
235 				break;
236 			} else {
237 				poller->next_expiration_in_us += poller->period_us;
238 			}
239 		}
240 
241 		TAILQ_INSERT_TAIL(&tmp_pollers, poller, tailq);
242 	}
243 
244 	TAILQ_SWAP(&tmp_pollers, &thread->pollers, ut_poller, tailq);
245 
246 	set_thread(original_thread_id);
247 
248 	return count;
249 }
250 
251 void
252 poll_threads(void)
253 {
254 	bool msg_processed;
255 	uint32_t i, count;
256 
257 	while (true) {
258 		msg_processed = false;
259 
260 		for (i = 0; i < g_ut_num_threads; i++) {
261 			count = poll_thread(i);
262 			if (count > 0) {
263 				msg_processed = true;
264 			}
265 		}
266 
267 		if (!msg_processed) {
268 			break;
269 		}
270 	}
271 }
272