/*
 * OpenSSH Multi-threaded AES-CTR Cipher
 *
 * Author: Benjamin Bennett <ben@psc.edu>
 * Copyright (c) 2008 Pittsburgh Supercomputing Center. All rights reserved.
 *
 * Based on original OpenSSH AES-CTR cipher. Small portions remain unchanged,
 * Copyright (c) 2003 Markus Friedl <markus@openbsd.org>
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */
#include "includes.h"

#include <sys/types.h>

#include <stdarg.h>
#include <stdint.h>	/* for uint16_t/uint32_t/uint64_t used below */
#include <string.h>

#include <openssl/evp.h>

#include "xmalloc.h"
#include "log.h"

#ifndef USE_BUILTIN_RIJNDAEL
#include <openssl/aes.h>
#endif

#include <pthread.h>

/*-------------------- TUNABLES --------------------*/
/* Number of pregen threads to use */
#define CIPHER_THREADS	2

/* Number of keystream queues */
#define NUMKQ		(CIPHER_THREADS + 2)

/* Length of a keystream queue */
#define KQLEN		4096

/* Processor cacheline length */
#define CACHELINE_LEN	64

/* Collect thread stats and print at cancellation when in debug mode */
/* #define CIPHER_THREAD_STATS */

/* Use single-byte XOR instead of 8-byte XOR */
/* #define CIPHER_BYTE_XOR */
/*-------------------- END TUNABLES --------------------*/
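
/*
 * Illustrative sizing under the defaults above: the pregenerated
 * keystream occupies NUMKQ * KQLEN * AES_BLOCK_SIZE =
 * 4 * 4096 * 16 bytes = 256 KiB, i.e. 64 KiB of keystream per queue.
 */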

#ifdef AES_CTR_MT


const EVP_CIPHER *evp_aes_ctr_mt(void);

#ifdef CIPHER_THREAD_STATS
/*
 * Struct to collect thread stats
 */
struct thread_stats {
	u_int	fills;
	u_int	skips;
	u_int	waits;
	u_int	drains;
};

/*
 * Debug print the thread stats
 * Use with pthread_cleanup_push for displaying at thread cancellation
 */
static void
thread_loop_stats(void *x)
{
	struct thread_stats *s = x;

	debug("tid %lu - %u fills, %u skips, %u waits",
	    (unsigned long)pthread_self(),
	    s->fills, s->skips, s->waits);
}

 #define STATS_STRUCT(s)	struct thread_stats s;
 #define STATS_INIT(s)		memset(&s, 0, sizeof(s))
 #define STATS_FILL(s)		s.fills++
 #define STATS_SKIP(s)		s.skips++
 #define STATS_WAIT(s)		s.waits++
 #define STATS_DRAIN(s)		s.drains++
#else
 #define STATS_STRUCT(s)
 #define STATS_INIT(s)
 #define STATS_FILL(s)
 #define STATS_SKIP(s)
 #define STATS_WAIT(s)
 #define STATS_DRAIN(s)
#endif

/* Keystream Queue state */
enum {
	KQINIT,
	KQEMPTY,
	KQFILLING,
	KQFULL,
	KQDRAINING
};
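
/*
 * Queue lifecycle, as driven by thread_loop() (producers) and
 * ssh_aes_ctr() (the consumer); KQINIT is only ever used for q[0]
 * at startup:
 *
 *   KQINIT     -- first pregen thread fills q[0]  --> KQDRAINING
 *   KQEMPTY    -- a pregen thread claims the queue --> KQFILLING
 *   KQFILLING  -- fill completes                  --> KQFULL
 *   KQFULL     -- consumer switches onto the queue --> KQDRAINING
 *   KQDRAINING -- consumer exhausts the queue     --> KQEMPTY
 */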

/* Keystream Queue struct */
struct kq {
	u_char		keys[KQLEN][AES_BLOCK_SIZE];
	u_char		ctr[AES_BLOCK_SIZE];
	u_char		pad0[CACHELINE_LEN];
	volatile int	qstate;
	pthread_mutex_t	lock;
	pthread_cond_t	cond;
	u_char		pad1[CACHELINE_LEN];
};
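
/*
 * The pad0/pad1 members are presumably there to keep the mutable
 * qstate/lock/cond fields and the bulk keystream data of adjacent
 * queues on separate cachelines (see CACHELINE_LEN above), avoiding
 * false sharing between producer and consumer threads.
 */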

/* Context struct */
struct ssh_aes_ctr_ctx
{
	struct kq	q[NUMKQ];
	AES_KEY		aes_ctx;
	STATS_STRUCT(stats)
	u_char		aes_counter[AES_BLOCK_SIZE];
	pthread_t	tid[CIPHER_THREADS];
	int		state;
	int		qidx;
	int		ridx;
};

/* <friedl>
 * increment counter 'ctr',
 * the counter is of size 'len' bytes and stored in network-byte-order.
 * (LSB at ctr[len-1], MSB at ctr[0])
 */
static void
ssh_ctr_inc(u_char *ctr, u_int len)
{
	int i;

	for (i = len - 1; i >= 0; i--)
		if (++ctr[i])	/* continue on overflow */
			return;
}
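
/*
 * Example: with len = 4, a counter of 00 00 00 ff increments to
 * 00 00 01 00; ++ctr[3] wraps to zero, so the loop carries into
 * ctr[2] and then returns.
 */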

/*
 * Add num to counter 'ctr'
 */
static void
ssh_ctr_add(u_char *ctr, uint32_t num, u_int len)
{
	int i;
	uint16_t n;

	for (n = 0, i = len - 1; i >= 0 && (num || n); i--) {
		n = ctr[i] + (num & 0xff) + n;
		num >>= 8;
		ctr[i] = n & 0xff;
		n >>= 8;
	}
}
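
/*
 * Example: thread_loop() below advances a freshly filled queue by
 * KQLEN * (NUMKQ - 1) = 4096 * 3 = 12288 (0x3000) blocks, so with
 * len = 4 a counter of 00 00 10 00 becomes 00 00 40 00.
 */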

/*
 * Threads may be cancelled in a pthread_cond_wait; we must unlock
 * the mutex on the way out.
 */
static void
thread_loop_cleanup(void *x)
{
	pthread_mutex_unlock((pthread_mutex_t *)x);
}

/*
 * The life of a pregen thread:
 *    Find empty keystream queues and fill them using their counter.
 *    When done, update counter for the next fill.
 */
static void *
thread_loop(void *x)
{
	AES_KEY key;
	STATS_STRUCT(stats)
	struct ssh_aes_ctr_ctx *c = x;
	struct kq *q;
	int i;
	int qidx;

	/* Thread stats on cancellation */
	STATS_INIT(stats);
#ifdef CIPHER_THREAD_STATS
	pthread_cleanup_push(thread_loop_stats, &stats);
#endif

	/* Thread-local copy of the AES key */
	memcpy(&key, &c->aes_ctx, sizeof(key));

	/*
	 * Handle the special case of startup: one thread must fill
	 * the first KQ then mark it as draining. Lock held throughout.
	 */
	if (pthread_equal(pthread_self(), c->tid[0])) {
		q = &c->q[0];
		pthread_mutex_lock(&q->lock);
		if (q->qstate == KQINIT) {
			for (i = 0; i < KQLEN; i++) {
				AES_encrypt(q->ctr, q->keys[i], &key);
				ssh_ctr_inc(q->ctr, AES_BLOCK_SIZE);
			}
			ssh_ctr_add(q->ctr, KQLEN * (NUMKQ - 1), AES_BLOCK_SIZE);
			q->qstate = KQDRAINING;
			STATS_FILL(stats);
			pthread_cond_broadcast(&q->cond);
		}
		pthread_mutex_unlock(&q->lock);
	} else
		STATS_SKIP(stats);

	/*
	 * The normal case is to find empty queues and fill them, skipping
	 * over queues already filled by other threads and stopping to wait
	 * for a draining queue to become empty.
	 *
	 * Multiple threads may be waiting on a draining queue and awoken
	 * when it empties.  The first thread to wake will mark it as
	 * filling; the others will move on to fill, skip, or wait on the
	 * next queue.
	 */
	for (qidx = 1;; qidx = (qidx + 1) % NUMKQ) {
		/* Check if we were cancelled; also checked in cond_wait */
		pthread_testcancel();

		/* Lock queue and block if it's draining */
		q = &c->q[qidx];
		pthread_mutex_lock(&q->lock);
		pthread_cleanup_push(thread_loop_cleanup, &q->lock);
		while (q->qstate == KQDRAINING || q->qstate == KQINIT) {
			STATS_WAIT(stats);
			pthread_cond_wait(&q->cond, &q->lock);
		}
		pthread_cleanup_pop(0);

		/* If filling or full, somebody else got it; skip */
		if (q->qstate != KQEMPTY) {
			pthread_mutex_unlock(&q->lock);
			STATS_SKIP(stats);
			continue;
		}

		/*
		 * Empty, let's fill it.
		 * The queue lock is relinquished while we do this so others
		 * can see that it is being filled.
		 */
		q->qstate = KQFILLING;
		pthread_mutex_unlock(&q->lock);
		for (i = 0; i < KQLEN; i++) {
			AES_encrypt(q->ctr, q->keys[i], &key);
			ssh_ctr_inc(q->ctr, AES_BLOCK_SIZE);
		}

		/* Re-lock, mark full and signal consumer */
		pthread_mutex_lock(&q->lock);
		ssh_ctr_add(q->ctr, KQLEN * (NUMKQ - 1), AES_BLOCK_SIZE);
		q->qstate = KQFULL;
		STATS_FILL(stats);
		pthread_cond_signal(&q->cond);
		pthread_mutex_unlock(&q->lock);
	}

#ifdef CIPHER_THREAD_STATS
	/* Stats */
	pthread_cleanup_pop(1);
#endif

	return NULL;
}

static int
ssh_aes_ctr(EVP_CIPHER_CTX *ctx, u_char *dest, const u_char *src,
    u_int len)
{
	struct ssh_aes_ctr_ctx *c;
	struct kq *q, *oldq;
	int ridx;
	u_char *buf;

	if (len == 0)
		return (1);
	if ((c = EVP_CIPHER_CTX_get_app_data(ctx)) == NULL)
		return (0);

	q = &c->q[c->qidx];
	ridx = c->ridx;

	/* src already padded to block multiple */
	while (len > 0) {
		buf = q->keys[ridx];

#ifdef CIPHER_BYTE_XOR
		dest[0] = src[0] ^ buf[0];
		dest[1] = src[1] ^ buf[1];
		dest[2] = src[2] ^ buf[2];
		dest[3] = src[3] ^ buf[3];
		dest[4] = src[4] ^ buf[4];
		dest[5] = src[5] ^ buf[5];
		dest[6] = src[6] ^ buf[6];
		dest[7] = src[7] ^ buf[7];
		dest[8] = src[8] ^ buf[8];
		dest[9] = src[9] ^ buf[9];
		dest[10] = src[10] ^ buf[10];
		dest[11] = src[11] ^ buf[11];
		dest[12] = src[12] ^ buf[12];
		dest[13] = src[13] ^ buf[13];
		dest[14] = src[14] ^ buf[14];
		dest[15] = src[15] ^ buf[15];
#else
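		/*
		 * Note: these 8-byte loads/stores assume dest, src and
		 * buf are suitably aligned and rely on type-punning the
		 * byte buffers; that holds here in practice but is not
		 * guaranteed by ISO C.
		 */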
		*(uint64_t *)dest = *(uint64_t *)src ^ *(uint64_t *)buf;
		*(uint64_t *)(dest + 8) = *(uint64_t *)(src + 8) ^
		    *(uint64_t *)(buf + 8);
#endif

		dest += 16;
		src += 16;
		len -= 16;
		ssh_ctr_inc(ctx->iv, AES_BLOCK_SIZE);

		/* Increment read index, switch queues on rollover */
		if ((ridx = (ridx + 1) % KQLEN) == 0) {
			oldq = q;

			/* Mark next queue draining, may need to wait */
			c->qidx = (c->qidx + 1) % NUMKQ;
			q = &c->q[c->qidx];
			pthread_mutex_lock(&q->lock);
			while (q->qstate != KQFULL) {
				STATS_WAIT(c->stats);
				pthread_cond_wait(&q->cond, &q->lock);
			}
			q->qstate = KQDRAINING;
			pthread_mutex_unlock(&q->lock);

			/* Mark consumed queue empty and signal producers */
			pthread_mutex_lock(&oldq->lock);
			oldq->qstate = KQEMPTY;
			STATS_DRAIN(c->stats);
			pthread_cond_broadcast(&oldq->cond);
			pthread_mutex_unlock(&oldq->lock);
		}
	}
	c->ridx = ridx;
	return (1);
}
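
/*
 * ssh_ctr_inc(ctx->iv, ...) in the loop above advances the EVP-level
 * counter once per consumed block, keeping ctx->iv in step with the
 * keystream position; ssh_aes_ctr_init() below seeds the queues from
 * ctx->iv, so the two views of the counter stay consistent.
 */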

#define HAVE_NONE       0
#define HAVE_KEY        1
#define HAVE_IV         2
static int
ssh_aes_ctr_init(EVP_CIPHER_CTX *ctx, const u_char *key, const u_char *iv,
    int enc)
{
	struct ssh_aes_ctr_ctx *c;
	int i;

	if ((c = EVP_CIPHER_CTX_get_app_data(ctx)) == NULL) {
		c = xmalloc(sizeof(*c));

		c->state = HAVE_NONE;
		for (i = 0; i < NUMKQ; i++) {
			pthread_mutex_init(&c->q[i].lock, NULL);
			pthread_cond_init(&c->q[i].cond, NULL);
		}

		STATS_INIT(c->stats);

		EVP_CIPHER_CTX_set_app_data(ctx, c);
	}

	if (c->state == (HAVE_KEY | HAVE_IV)) {
		/* Cancel pregen threads */
		for (i = 0; i < CIPHER_THREADS; i++)
			pthread_cancel(c->tid[i]);
		for (i = 0; i < CIPHER_THREADS; i++)
			pthread_join(c->tid[i], NULL);
		/* Start over getting key & iv */
		c->state = HAVE_NONE;
	}

	if (key != NULL) {
		AES_set_encrypt_key(key, EVP_CIPHER_CTX_key_length(ctx) * 8,
		    &c->aes_ctx);
		c->state |= HAVE_KEY;
	}

	if (iv != NULL) {
		memcpy(ctx->iv, iv, AES_BLOCK_SIZE);
		c->state |= HAVE_IV;
	}

	if (c->state == (HAVE_KEY | HAVE_IV)) {
		/* Clear queues */
		memcpy(c->q[0].ctr, ctx->iv, AES_BLOCK_SIZE);
		c->q[0].qstate = KQINIT;
		for (i = 1; i < NUMKQ; i++) {
			memcpy(c->q[i].ctr, ctx->iv, AES_BLOCK_SIZE);
			ssh_ctr_add(c->q[i].ctr, i * KQLEN, AES_BLOCK_SIZE);
			c->q[i].qstate = KQEMPTY;
		}
		c->qidx = 0;
		c->ridx = 0;

		/* Start threads */
		for (i = 0; i < CIPHER_THREADS; i++) {
			pthread_create(&c->tid[i], NULL, thread_loop, c);
		}
		pthread_mutex_lock(&c->q[0].lock);
		while (c->q[0].qstate != KQDRAINING)
			pthread_cond_wait(&c->q[0].cond, &c->q[0].lock);
		pthread_mutex_unlock(&c->q[0].lock);
	}
	return (1);
}
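
/*
 * Counter layout under the defaults, as set up above: queue i starts
 * at IV + i * KQLEN blocks, so q[0]..q[3] initially cover blocks
 * 0-4095, 4096-8191, 8192-12287 and 12288-16383.  Each fill then
 * advances that queue's counter by KQLEN * (NUMKQ - 1) = 12288
 * blocks (exactly the keystream produced by the other queues in the
 * meantime), so the queues leapfrog through the counter space without
 * overlap.
 */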

static int
ssh_aes_ctr_cleanup(EVP_CIPHER_CTX *ctx)
{
	struct ssh_aes_ctr_ctx *c;
	int i;

	if ((c = EVP_CIPHER_CTX_get_app_data(ctx)) != NULL) {
#ifdef CIPHER_THREAD_STATS
		debug("main thread: %u drains, %u waits", c->stats.drains,
		    c->stats.waits);
#endif
		/* Cancel pregen threads */
		for (i = 0; i < CIPHER_THREADS; i++)
			pthread_cancel(c->tid[i]);
		for (i = 0; i < CIPHER_THREADS; i++)
			pthread_join(c->tid[i], NULL);

		memset(c, 0, sizeof(*c));
		xfree(c);
		EVP_CIPHER_CTX_set_app_data(ctx, NULL);
	}
	return (1);
}

/* <friedl> */
const EVP_CIPHER *
evp_aes_ctr_mt(void)
{
	static EVP_CIPHER aes_ctr;

	memset(&aes_ctr, 0, sizeof(EVP_CIPHER));
	aes_ctr.nid = NID_undef;
	aes_ctr.block_size = AES_BLOCK_SIZE;
	aes_ctr.iv_len = AES_BLOCK_SIZE;
	aes_ctr.key_len = 16;
	aes_ctr.init = ssh_aes_ctr_init;
	aes_ctr.cleanup = ssh_aes_ctr_cleanup;
	aes_ctr.do_cipher = ssh_aes_ctr;
#ifndef SSH_OLD_EVP
	aes_ctr.flags = EVP_CIPH_CBC_MODE | EVP_CIPH_VARIABLE_LENGTH |
	    EVP_CIPH_ALWAYS_CALL_INIT | EVP_CIPH_CUSTOM_IV;
#endif
	return (&aes_ctr);
}
#endif /* AES_CTR_MT */
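
/*
 * Minimal usage sketch (a hypothetical caller; not part of this file).
 * The cipher is driven through the standard EVP interface, and the
 * length handed to EVP_Cipher() must be a multiple of AES_BLOCK_SIZE,
 * since ssh_aes_ctr() assumes block-padded input:
 *
 *	EVP_CIPHER_CTX ctx;
 *	u_char key[16], iv[AES_BLOCK_SIZE], in[256], out[256];
 *
 *	EVP_CIPHER_CTX_init(&ctx);
 *	EVP_CipherInit(&ctx, evp_aes_ctr_mt(), key, iv, 1);
 *	EVP_Cipher(&ctx, out, in, sizeof(in));
 *	EVP_CIPHER_CTX_cleanup(&ctx);
 *
 * In OpenSSH itself the cipher layer performs the equivalent calls.
 */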