xref: /openbsd-src/usr.bin/ssh/sntrup761.c (revision f6246b7f478ea7b2b6df549ae5998f8112d22650)
1 /*  $OpenBSD: sntrup761.c,v 1.2 2020/12/30 14:13:28 tobhe Exp $ */
2 
3 /*
4  * Public Domain, Authors:
5  * - Daniel J. Bernstein
6  * - Chitchanok Chuengsatiansup
7  * - Tanja Lange
8  * - Christine van Vredendaal
9  */
10 
11 #include <string.h>
12 #include "crypto_api.h"
13 #include "int32_minmax.inc"
14 
/*
 * Identity namespace: CRYPTO_NAMESPACE(name) expands to plain name, so the
 * "#define X CRYPTO_NAMESPACE(X)" lines later in this file leave
 * identifiers unchanged.
 */
#define CRYPTO_NAMESPACE(s) s

/* from supercop-20201130/crypto_sort/int32/portable4/sort.c */
/*
 * Spell the sorter's "int32" as crypto_int32.  NOTE(review): this macro is
 * never #undef'd, so every later use of int32 in this file -- including
 * "typedef int32_t int32;" -- is rewritten to crypto_int32 (presumably
 * crypto_api.h already defines crypto_int32 as int32_t; confirm).
 */
#define int32 crypto_int32
19 
20 
/*
 * Sort n crypto_int32 values in place using only int32_MINMAX
 * compare-exchange steps.  int32_MINMAX comes from int32_minmax.inc;
 * presumably it leaves min(a,b) in its first argument and max(a,b) in its
 * second, giving ascending order -- confirm against that file.
 *
 * Every loop bound and the early exit depend only on n and the loop
 * counters, never on the array contents, so the same positions are
 * compared for every input of length n.
 *
 * array: n contiguous crypto_int32 values (modified in place)
 * n:     element count; n < 2 returns immediately
 */
static void crypto_sort_int32(void *array,long long n)
{
  long long top,p,q,r,i,j;
  int32 *x = array;

  if (n < 2) return;
  /* top: largest power of two strictly below n (loop ends with 2*top >= n) */
  top = 1;
  while (top < n - top) top += top;

  /* p: compare distance, halved each pass from top down to 1 */
  for (p = top;p >= 1;p >>= 1) {
    /* within each full block of 2p entries, compare x[j] with x[j+p] */
    i = 0;
    while (i + 2 * p <= n) {
      for (j = i;j < i + p;++j)
        int32_MINMAX(x[j],x[j+p]);
      i += 2 * p;
    }
    /* same for the trailing partial block */
    for (j = i;j < n - p;++j)
      int32_MINMAX(x[j],x[j+p]);

    /*
     * For each larger distance q (top, top/2, ..., 2p): a = x[j+p] is
     * compare-exchanged against x[j+q], x[j+q/2], ..., x[j+2p] and the
     * result is written back to x[j+p].  i and j are not reset between
     * q values, so work resumes where the previous q stopped.
     */
    i = 0;
    j = 0;
    for (q = top;q > p;q >>= 1) {
      /* finish a block of p positions left incomplete by the previous q */
      if (j != i) for (;;) {
        if (j == n - q) goto done;
        int32 a = x[j + p];
        for (r = q;r > p;r >>= 1)
          int32_MINMAX(a,x[j + r]);
        x[j + p] = a;
        ++j;
        if (j == i + p) {
          i += 2 * p;
          break;
        }
      }
      /* whole blocks of p positions that lie below n - q */
      while (i + p <= n - q) {
        for (j = i;j < i + p;++j) {
          int32 a = x[j + p];
          for (r = q;r > p;r >>= 1)
            int32_MINMAX(a,x[j+r]);
          x[j + p] = a;
        }
        i += 2 * p;
      }
      /* now i + p > n - q */
      /* remaining positions i .. n-q-1 (may stop mid-block) */
      j = i;
      while (j < n - q) {
        int32 a = x[j + p];
        for (r = q;r > p;r >>= 1)
          int32_MINMAX(a,x[j+r]);
        x[j + p] = a;
        ++j;
      }

      done: ;
    }
  }
}
78 
79 /* from supercop-20201130/crypto_sort/uint32/useint32/sort.c */
80 
81 /* can save time by vectorizing xor loops */
82 /* can save time by integrating xor loops with int32_sort */
83 
84 static void crypto_sort_uint32(void *array,long long n)
85 {
86   crypto_uint32 *x = array;
87   long long j;
88   for (j = 0;j < n;++j) x[j] ^= 0x80000000;
89   crypto_sort_int32(array,n);
90   for (j = 0;j < n;++j) x[j] ^= 0x80000000;
91 }
92 
93 /* from supercop-20201130/crypto_kem/sntrup761/ref/uint64.h */
#ifndef UINT64_H
#define UINT64_H

/* Short name for uint64_t; used for the 64-bit products in uint32_divmod_uint14. */

typedef uint64_t uint64;

#endif

/* from supercop-20201130/crypto_kem/sntrup761/ref/uint16.h */
#ifndef UINT16_H
#define UINT16_H

/* Short name for uint16_t (moduli, residues and encoded digits). */
typedef uint16_t uint16;

#endif
109 
110 /* from supercop-20201130/crypto_kem/sntrup761/ref/uint32.h */
#ifndef UINT32_H
#define UINT32_H

#define uint32_div_uint14 CRYPTO_NAMESPACE(uint32_div_uint14)
#define uint32_mod_uint14 CRYPTO_NAMESPACE(uint32_mod_uint14)
#define uint32_divmod_uint14 CRYPTO_NAMESPACE(uint32_divmod_uint14)


typedef uint32_t uint32;

/*
assuming 1 <= m < 16384:
q = uint32_div_uint14(x,m) means q = x/m
r = uint32_mod_uint14(x,m) means r = x%m
uint32_divmod_uint14(&q,&r,x,m) means q = x/m, r = x%m
*/

/*
 * NOTE(review): div and mod are declared extern (external linkage) while
 * divmod is file-static; confirm the external symbols are intended.
 */
extern uint32 uint32_div_uint14(uint32,uint16);
extern uint16 uint32_mod_uint14(uint32,uint16);
static void uint32_divmod_uint14(uint32 *,uint16 *,uint32,uint16);

#endif
133 
134 /* from supercop-20201130/crypto_kem/sntrup761/ref/int8.h */
#ifndef INT8_H
#define INT8_H

/* Short name for int8_t (base type of the "small" coefficients). */
typedef int8_t int8;

#endif

/* from supercop-20201130/crypto_kem/sntrup761/ref/int16.h */
#ifndef INT16_H
#define INT16_H

/* Short name for int16_t (base type of Fq coefficients). */
typedef int16_t int16;

#endif
149 
150 /* from supercop-20201130/crypto_kem/sntrup761/ref/int32.h */
#ifndef INT32_H
#define INT32_H

#define int32_div_uint14 CRYPTO_NAMESPACE(int32_div_uint14)
#define int32_mod_uint14 CRYPTO_NAMESPACE(int32_mod_uint14)
#define int32_divmod_uint14 CRYPTO_NAMESPACE(int32_divmod_uint14)


/*
 * NOTE(review): int32 is a macro for crypto_int32 (defined near the top of
 * this file), so this line actually typedefs crypto_int32.
 */
typedef int32_t int32;

/*
assuming 1 <= m < 16384:
q = int32_div_uint14(x,m) means q = floor(x/m)
r = int32_mod_uint14(x,m) means r = x - m*floor(x/m), so 0 <= r < m
int32_divmod_uint14(&q,&r,x,m) computes both q and r as above
(for negative x this rounds toward -infinity, unlike C's / and %)
*/

extern int32 int32_div_uint14(int32,uint16);
extern uint16 int32_mod_uint14(int32,uint16);
static void int32_divmod_uint14(int32 *,uint16 *,int32,uint16);

#endif
173 
174 /* from supercop-20201130/crypto_kem/sntrup761/ref/uint32.c */
175 
176 /*
177 CPU division instruction typically takes time depending on x.
178 This software is designed to take time independent of x.
179 Time still varies depending on m; user must ensure that m is constant.
180 Time also varies on CPUs where multiplication is variable-time.
181 There could be more CPU issues.
182 There could also be compiler issues.
183 */
184 
/*
 * q = x / m and r = x % m, for 1 <= m < 16384 (caller guarantees).
 *
 * The only hardware division is v = 2^31 / m, which depends on m alone.
 * x is divided by multiplying with that reciprocal estimate twice and then
 * applying one masked correction, so no branch depends on x.
 */
static void uint32_divmod_uint14(uint32 *q,uint16 *r,uint32 x,uint16 m)
{
  uint32 v = 0x80000000;
  uint32 qpart;
  uint32 mask;

  v /= m;

  /* caller guarantees m > 0 */
  /* caller guarantees m < 16384 */
  /* vm <= 2^31 <= vm+m-1 */
  /* xvm <= 2^31 x <= xvm+x(m-1) */

  *q = 0;

  qpart = (x*(uint64)v)>>31;
  /* 2^31 qpart <= xv <= 2^31 qpart + 2^31-1 */
  /* 2^31 qpart m <= xvm <= 2^31 qpart m + (2^31-1)m */
  /* 2^31 qpart m <= 2^31 x <= 2^31 qpart m + (2^31-1)m + x(m-1) */
  /* 0 <= 2^31 newx <= (2^31-1)m + x(m-1) */
  /* 0 <= newx <= (1-1/2^31)m + x(m-1)/2^31 */
  /* 0 <= newx <= (1-1/2^31)(2^14-1) + (2^32-1)((2^14-1)-1)/2^31 */

  x -= qpart*m; *q += qpart;
  /* x <= 49146 */

  qpart = (x*(uint64)v)>>31;
  /* 0 <= newx <= (1-1/2^31)m + x(m-1)/2^31 */
  /* 0 <= newx <= m + 49146(2^14-1)/2^31 */
  /* 0 <= newx <= m + 0.4 */
  /* 0 <= newx <= m */

  x -= qpart*m; *q += qpart;
  /* x <= m */

  /*
   * Tentatively subtract one more m.  Since x <= m < 2^14, this wraps (bit
   * 31 set) exactly when x < m; then mask is all-ones, which adds m back
   * and cancels the +1 on *q (adding 0xffffffff subtracts 1).
   */
  x -= m; *q += 1;
  mask = -(x>>31);
  x += mask&(uint32)m; *q += mask;
  /* x < m */

  *r = x;
}
227 
228 uint32 uint32_div_uint14(uint32 x,uint16 m)
229 {
230   uint32 q;
231   uint16 r;
232   uint32_divmod_uint14(&q,&r,x,m);
233   return q;
234 }
235 
236 uint16 uint32_mod_uint14(uint32 x,uint16 m)
237 {
238   uint32 q;
239   uint16 r;
240   uint32_divmod_uint14(&q,&r,x,m);
241   return r;
242 }
243 
244 /* from supercop-20201130/crypto_kem/sntrup761/ref/int32.c */
245 
/*
 * Signed division: q = floor(x/m), r = x - q*m with 0 <= r < m, for
 * 1 <= m < 16384.  Shifts x by 2^31 into the unsigned range, divides both
 * 2^31+x and 2^31 by m, and subtracts the results.  If the remainder
 * difference went negative, bit 15 of the uint16 is set (this relies on
 * m < 2^14).  In that case a mask adds m back and decrements q, with no
 * branch.
 * NOTE(review): *q = uq converts uint32 to int32, which is
 * implementation-defined for negative quotients (C11 6.3.1.3p3);
 * two's-complement compilers wrap as intended.
 */
static void int32_divmod_uint14(int32 *q,uint16 *r,int32 x,uint16 m)
{
  uint32 uq,uq2;
  uint16 ur,ur2;
  uint32 mask;

  uint32_divmod_uint14(&uq,&ur,0x80000000+(uint32)x,m);
  uint32_divmod_uint14(&uq2,&ur2,0x80000000,m);
  ur -= ur2; uq -= uq2;
  mask = -(uint32)(ur>>15);
  ur += mask&m; uq += mask;
  *r = ur; *q = uq;
}
259 
260 int32 int32_div_uint14(int32 x,uint16 m)
261 {
262   int32 q;
263   uint16 r;
264   int32_divmod_uint14(&q,&r,x,m);
265   return q;
266 }
267 
268 uint16 int32_mod_uint14(int32 x,uint16 m)
269 {
270   int32 q;
271   uint16 r;
272   int32_divmod_uint14(&q,&r,x,m);
273   return r;
274 }
275 
276 /* from supercop-20201130/crypto_kem/sntrup761/ref/paramsmenu.h */
/*
 * Build-time parameter selection.  This file is configured for SIZE761
 * with SNTRUP (Streamlined NTRU Prime).  Because LPR is #undef'd here,
 * every "#ifdef LPR" section below is compiled out.
 * NOTE(review): SNTRUP itself is not tested anywhere in this file; only
 * LPR's presence or absence selects the variant.
 */
/* pick one of these three: */
#define SIZE761
#undef SIZE653
#undef SIZE857

/* pick one of these two: */
#define SNTRUP /* Streamlined NTRU Prime */
#undef LPR /* NTRU LPRime */
285 
286 /* from supercop-20201130/crypto_kem/sntrup761/ref/params.h */
#ifndef params_H
#define params_H

/* menu of parameter choices: */


/* what the menu means: */
/*
 * p: number of coefficients in every polynomial array below; products are
 *    reduced using x^p = x + 1.
 * q: modulus of Fq coefficients.
 * w: required number of nonzero coefficients (see Weightw_mask/Short_fromlist).
 * Rounded_bytes / Rq_bytes: byte lengths used for the Rounded_* and Rq_*
 *    encodings.
 * tau0..tau3: constants used by Top()/Right() (LPR only).
 * I: number of input bits (LPR only).
 * NOTE(review): these single-letter object-like macros stay defined for the
 * rest of the file, so no later identifier may be named p, q, w or I.
 */

#if defined(SIZE761)
#define p 761
#define q 4591
#define Rounded_bytes 1007
#ifndef LPR
#define Rq_bytes 1158
#define w 286
#else
#define w 250
#define tau0 2156
#define tau1 114
#define tau2 2007
#define tau3 287
#endif

#elif defined(SIZE653)
#define p 653
#define q 4621
#define Rounded_bytes 865
#ifndef LPR
#define Rq_bytes 994
#define w 288
#else
#define w 252
#define tau0 2175
#define tau1 113
#define tau2 2031
#define tau3 290
#endif

#elif defined(SIZE857)
#define p 857
#define q 5167
#define Rounded_bytes 1152
#ifndef LPR
#define Rq_bytes 1322
#define w 322
#else
#define w 281
#define tau0 2433
#define tau1 101
#define tau2 2265
#define tau3 324
#endif

#else
#error "no parameter set defined"
#endif

#ifdef LPR
#define I 256
#endif

#endif
349 
350 /* from supercop-20201130/crypto_kem/sntrup761/ref/Decode.h */
#ifndef Decode_H
#define Decode_H

#define Decode CRYPTO_NAMESPACE(Decode)

/* Decode(R,s,M,len) */
/* assumes 0 < M[i] < 16384 */
/* produces 0 <= R[i] < M[i] */
/* inverse of Encode for the same M and len */
static void Decode(uint16 *,const unsigned char *,const uint16 *,long long);

#endif
362 
363 /* from supercop-20201130/crypto_kem/sntrup761/ref/Decode.c */
364 
/*
 * Recover out[0..len-1] (0 <= out[i] < M[i]) from the byte string S that
 * Encode produced for the same M and len.
 *
 * Adjacent moduli are merged into m = M[i]*M[i+1].  For each pair, the 0,
 * 1 or 2 low bytes that Encode emitted (m < 16384, m < 256*16384 or
 * larger) are read from S and the reduced modulus is recomputed exactly as
 * Encode did.  The function then recurses on the half-length list and
 * splits each recovered pair value with divmod by M[i].
 *
 * NOTE(review): there is no length argument for S; the caller must supply
 * the full encoding (the number of bytes read depends only on M and len).
 */
static void Decode(uint16 *out,const unsigned char *S,const uint16 *M,long long len)
{
  if (len == 1) {
    /* base case: Encode wrote 0, 1 or 2 little-endian bytes for M[0] */
    if (M[0] == 1)
      *out = 0;
    else if (M[0] <= 256)
      *out = uint32_mod_uint14(S[0],M[0]);
    else
      *out = uint32_mod_uint14(S[0]+(((uint16)S[1])<<8),M[0]);
  }
  if (len > 1) {
    uint16 R2[(len+1)/2];
    uint16 M2[(len+1)/2];
    uint16 bottomr[len/2];   /* value carried by the bytes read for each pair */
    uint32 bottomt[len/2];   /* weight of those bytes: 1, 256 or 65536 */
    long long i;
    for (i = 0;i < len-1;i += 2) {
      uint32 m = M[i]*(uint32) M[i+1];
      if (m > 256*16383) {
        bottomt[i/2] = 256*256;
        bottomr[i/2] = S[0]+256*S[1];
        S += 2;
        M2[i/2] = (((m+255)>>8)+255)>>8;
      } else if (m >= 16384) {
        bottomt[i/2] = 256;
        bottomr[i/2] = S[0];
        S += 1;
        M2[i/2] = (m+255)>>8;
      } else {
        bottomt[i/2] = 1;
        bottomr[i/2] = 0;
        M2[i/2] = m;
      }
    }
    /* odd len: the last modulus passes through unpaired */
    if (i < len)
      M2[i/2] = M[i];
    Decode(R2,S,M2,(len+1)/2);
    for (i = 0;i < len-1;i += 2) {
      uint32 r = bottomr[i/2];
      uint32 r1;
      uint16 r0;
      r += bottomt[i/2]*R2[i/2];
      uint32_divmod_uint14(&r1,&r0,r,M[i]);
      r1 = uint32_mod_uint14(r1,M[i+1]); /* only needed for invalid inputs */
      *out++ = r0;
      *out++ = r1;
    }
    if (i < len)
      *out++ = R2[i/2];
  }
}
416 
417 /* from supercop-20201130/crypto_kem/sntrup761/ref/Encode.h */
#ifndef Encode_H
#define Encode_H

#define Encode CRYPTO_NAMESPACE(Encode)

/* Encode(s,R,M,len) */
/* assumes 0 <= R[i] < M[i] < 16384 */
/* output length depends only on M and len; Decode is the inverse */
static void Encode(unsigned char *,const uint16 *,const uint16 *,long long);

#endif
428 
429 /* from supercop-20201130/crypto_kem/sntrup761/ref/Encode.c */
430 
/* 0 <= R[i] < M[i] < 16384 */
/*
 * Mixed-radix encoding of R[0..len-1] into bytes at out.
 *
 * Adjacent pairs are merged into one digit r = R[i] + R[i+1]*M[i] with
 * modulus m = M[i]*M[i+1].  Low bytes of r are emitted while m >= 16384,
 * with m shrinking to ceil(m/256) each time.  The reduced digits are then
 * encoded recursively.  The loop conditions depend only on M, so the
 * number of bytes written is independent of R.
 */
static void Encode(unsigned char *out,const uint16 *R,const uint16 *M,long long len)
{
  if (len == 1) {
    /* base case: little-endian bytes until the modulus fits in 1 */
    uint16 r = R[0];
    uint16 m = M[0];
    while (m > 1) {
      *out++ = r;
      r >>= 8;
      m = (m+255)>>8;
    }
  }
  if (len > 1) {
    uint16 R2[(len+1)/2];
    uint16 M2[(len+1)/2];
    long long i;
    for (i = 0;i < len-1;i += 2) {
      uint32 m0 = M[i];
      uint32 r = R[i]+R[i+1]*m0;
      uint32 m = M[i+1]*m0;
      while (m >= 16384) {
        *out++ = r;
        r >>= 8;
        m = (m+255)>>8;
      }
      R2[i/2] = r;
      M2[i/2] = m;
    }
    /* odd len: the last digit passes through unpaired */
    if (i < len) {
      R2[i/2] = R[i];
      M2[i/2] = M[i];
    }
    Encode(out,R2,M2,(len+1)/2);
  }
}
466 
467 /* from supercop-20201130/crypto_kem/sntrup761/ref/kem.c */
468 
/* Empty LPR-only section kept from kem.c; presumably its contents were
   removed when the sources were merged into this file. */
#ifdef LPR
#endif
471 
472 
473 /* ----- masks */
474 
475 #ifndef LPR
476 
477 /* return -1 if x!=0; else return 0 */
478 static int int16_nonzero_mask(int16 x)
479 {
480   uint16 u = x; /* 0, else 1...65535 */
481   uint32 v = u; /* 0, else 1...65535 */
482   v = -v; /* 0, else 2^32-65535...2^32-1 */
483   v >>= 31; /* 0, else 1 */
484   return -v; /* 0, else -1 */
485 }
486 
487 #endif
488 
489 /* return -1 if x<0; otherwise return 0 */
490 static int int16_negative_mask(int16 x)
491 {
492   uint16 u = x;
493   u >>= 15;
494   return -(int) u;
495   /* alternative with gcc -fwrapv: */
496   /* x>>15 compiles to CPU's arithmetic right shift */
497 }
498 
499 /* ----- arithmetic mod 3 */
500 
501 typedef int8 small;
502 
503 /* F3 is always represented as -1,0,1 */
504 /* so ZZ_fromF3 is a no-op */
505 
506 /* x must not be close to top int16 */
507 static small F3_freeze(int16 x)
508 {
509   return int32_mod_uint14(x+1,3)-1;
510 }
511 
512 /* ----- arithmetic mod q */
513 
514 #define q12 ((q-1)/2)
515 typedef int16 Fq;
516 /* always represented as -q12...q12 */
517 /* so ZZ_fromFq is a no-op */
518 
519 /* x must not be close to top int32 */
520 static Fq Fq_freeze(int32 x)
521 {
522   return int32_mod_uint14(x+q12,q)-q12;
523 }
524 
525 #ifndef LPR
526 
527 static Fq Fq_recip(Fq a1)
528 {
529   int i = 1;
530   Fq ai = a1;
531 
532   while (i < q-2) {
533     ai = Fq_freeze(a1*(int32)ai);
534     i += 1;
535   }
536   return ai;
537 }
538 
539 #endif
540 
541 /* ----- Top and Right */
542 
/*
 * LPR only: compiled out in this build because the parameter menu
 * #undefs LPR.
 */
#ifdef LPR
#define tau 16  /* presumably the number of Top() values (4-bit, see Top_encode) */

/* Compress an Fq value: (tau1*(C+tau0) + 2^14) >> 15, i.e. rounded scaling. */
static int8 Top(Fq C)
{
  return (tau1*(int32)(C+tau0)+16384)>>15;
}

/* Expand a Top() value back into Fq: tau3*T - tau2, reduced mod q. */
static Fq Right(int8 T)
{
  return Fq_freeze(tau3*(int32)T-tau2);
}
#endif
556 
557 /* ----- small polynomials */
558 
559 #ifndef LPR
560 
561 /* 0 if Weightw_is(r), else -1 */
562 static int Weightw_mask(small *r)
563 {
564   int weight = 0;
565   int i;
566 
567   for (i = 0;i < p;++i) weight += r[i]&1;
568   return int16_nonzero_mask(weight-w);
569 }
570 
571 /* R3_fromR(R_fromRq(r)) */
572 static void R3_fromRq(small *out,const Fq *r)
573 {
574   int i;
575   for (i = 0;i < p;++i) out[i] = F3_freeze(r[i]);
576 }
577 
578 /* h = f*g in the ring R3 */
579 static void R3_mult(small *h,const small *f,const small *g)
580 {
581   small fg[p+p-1];
582   small result;
583   int i,j;
584 
585   for (i = 0;i < p;++i) {
586     result = 0;
587     for (j = 0;j <= i;++j) result = F3_freeze(result+f[j]*g[i-j]);
588     fg[i] = result;
589   }
590   for (i = p;i < p+p-1;++i) {
591     result = 0;
592     for (j = i-p+1;j < p;++j) result = F3_freeze(result+f[j]*g[i-j]);
593     fg[i] = result;
594   }
595 
596   for (i = p+p-2;i >= p;--i) {
597     fg[i-p] = F3_freeze(fg[i-p]+fg[i]);
598     fg[i-p+1] = F3_freeze(fg[i-p+1]+fg[i]);
599   }
600 
601   for (i = 0;i < p;++i) h[i] = fg[i];
602 }
603 
/* returns 0 if recip succeeded; else -1 */
/*
 * out = 1/in in R3 (coefficients mod 3, reduced with x^p = x + 1).
 *
 * Runs a fixed 2p-1 swap/eliminate steps, with swaps done by masks rather
 * than branches.  f starts as the coefficient-reversed modulus
 * (1 - x^(p-1) - x^p, i.e. x^p - x - 1 reversed) and g as the reversed
 * input.  The loop appears to follow the Bernstein-Yang "divstep" pattern:
 * a delta counter, with a conditional swap when delta > 0 and g[0] != 0.
 * The result is v, un-reversed and scaled by f[0].
 * Returns int16_nonzero_mask(delta): 0 iff delta ends at 0.
 */
static int R3_recip(small *out,const small *in)
{
  small f[p+1],g[p+1],v[p+1],r[p+1];
  int i,loop,delta;
  int sign,swap,t;

  for (i = 0;i < p+1;++i) v[i] = 0;
  for (i = 0;i < p+1;++i) r[i] = 0;
  r[0] = 1;
  for (i = 0;i < p;++i) f[i] = 0;
  f[0] = 1; f[p-1] = f[p] = -1;
  for (i = 0;i < p;++i) g[p-1-i] = in[i];
  g[p] = 0;

  delta = 1;

  for (loop = 0;loop < 2*p-1;++loop) {
    /* v *= x */
    for (i = p;i > 0;--i) v[i] = v[i-1];
    v[0] = 0;

    /* -g0*f0 is symmetric in f and g, so computing it before the swap is fine */
    sign = -g[0]*f[0];
    /* swap = -1 iff delta > 0 and g[0] != 0; then delta = -delta */
    swap = int16_negative_mask(-delta) & int16_nonzero_mask(g[0]);
    delta ^= swap&(delta^-delta);
    delta += 1;

    /* masked conditional swap of (f,g) and (v,r) */
    for (i = 0;i < p+1;++i) {
      t = swap&(f[i]^g[i]); f[i] ^= t; g[i] ^= t;
      t = swap&(v[i]^r[i]); v[i] ^= t; r[i] ^= t;
    }

    /* eliminate g[0] (values are -1,0,1 so sign*f cancels it mod 3) */
    for (i = 0;i < p+1;++i) g[i] = F3_freeze(g[i]+sign*f[i]);
    for (i = 0;i < p+1;++i) r[i] = F3_freeze(r[i]+sign*v[i]);

    /* g /= x */
    for (i = 0;i < p;++i) g[i] = g[i+1];
    g[p] = 0;
  }

  sign = f[0];
  for (i = 0;i < p;++i) out[i] = sign*v[p-1-i];

  return int16_nonzero_mask(delta);
}
647 
648 #endif
649 
650 /* ----- polynomials mod q */
651 
652 /* h = f*g in the ring Rq */
653 static void Rq_mult_small(Fq *h,const Fq *f,const small *g)
654 {
655   Fq fg[p+p-1];
656   Fq result;
657   int i,j;
658 
659   for (i = 0;i < p;++i) {
660     result = 0;
661     for (j = 0;j <= i;++j) result = Fq_freeze(result+f[j]*(int32)g[i-j]);
662     fg[i] = result;
663   }
664   for (i = p;i < p+p-1;++i) {
665     result = 0;
666     for (j = i-p+1;j < p;++j) result = Fq_freeze(result+f[j]*(int32)g[i-j]);
667     fg[i] = result;
668   }
669 
670   for (i = p+p-2;i >= p;--i) {
671     fg[i-p] = Fq_freeze(fg[i-p]+fg[i]);
672     fg[i-p+1] = Fq_freeze(fg[i-p+1]+fg[i]);
673   }
674 
675   for (i = 0;i < p;++i) h[i] = fg[i];
676 }
677 
678 #ifndef LPR
679 
680 /* h = 3f in Rq */
681 static void Rq_mult3(Fq *h,const Fq *f)
682 {
683   int i;
684 
685   for (i = 0;i < p;++i) h[i] = Fq_freeze(3*f[i]);
686 }
687 
/* out = 1/(3*in) in Rq */
/* returns 0 if recip succeeded; else -1 */
/*
 * Fixed 2p-1 step swap/eliminate loop over Fq, with swaps done by masks.
 * f starts as the reversed modulus (x^p - x - 1 reversed) and g as the
 * reversed input.  Each step forms f0*g - g0*f (cross-multiplication, no
 * division).  r starts at 1/3 so the result carries the extra factor 1/3.
 * The final v is scaled by 1/f[0] and un-reversed into out.
 * Returns int16_nonzero_mask(delta): 0 iff delta ends at 0.
 */
static int Rq_recip3(Fq *out,const small *in)
{
  Fq f[p+1],g[p+1],v[p+1],r[p+1];
  int i,loop,delta;
  int swap,t;
  int32 f0,g0;
  Fq scale;

  for (i = 0;i < p+1;++i) v[i] = 0;
  for (i = 0;i < p+1;++i) r[i] = 0;
  r[0] = Fq_recip(3);
  for (i = 0;i < p;++i) f[i] = 0;
  f[0] = 1; f[p-1] = f[p] = -1;
  for (i = 0;i < p;++i) g[p-1-i] = in[i];
  g[p] = 0;

  delta = 1;

  for (loop = 0;loop < 2*p-1;++loop) {
    /* v *= x */
    for (i = p;i > 0;--i) v[i] = v[i-1];
    v[0] = 0;

    /* swap = -1 iff delta > 0 and g[0] != 0; then delta = -delta */
    swap = int16_negative_mask(-delta) & int16_nonzero_mask(g[0]);
    delta ^= swap&(delta^-delta);
    delta += 1;

    /* masked conditional swap of (f,g) and (v,r) */
    for (i = 0;i < p+1;++i) {
      t = swap&(f[i]^g[i]); f[i] ^= t; g[i] ^= t;
      t = swap&(v[i]^r[i]); v[i] ^= t; r[i] ^= t;
    }

    /* g = f0*g - g0*f makes g[0] zero; r tracks the same combination */
    f0 = f[0];
    g0 = g[0];
    for (i = 0;i < p+1;++i) g[i] = Fq_freeze(f0*g[i]-g0*f[i]);
    for (i = 0;i < p+1;++i) r[i] = Fq_freeze(f0*r[i]-g0*v[i]);

    /* g /= x */
    for (i = 0;i < p;++i) g[i] = g[i+1];
    g[p] = 0;
  }

  scale = Fq_recip(f[0]);
  for (i = 0;i < p;++i) out[i] = Fq_freeze(scale*(int32)v[p-1-i]);

  return int16_nonzero_mask(delta);
}
735 
736 #endif
737 
738 /* ----- rounded polynomials mod q */
739 
740 static void Round(Fq *out,const Fq *a)
741 {
742   int i;
743   for (i = 0;i < p;++i) out[i] = a[i]-F3_freeze(a[i]);
744 }
745 
746 /* ----- sorting to generate short polynomial */
747 
748 static void Short_fromlist(small *out,const uint32 *in)
749 {
750   uint32 L[p];
751   int i;
752 
753   for (i = 0;i < w;++i) L[i] = in[i]&(uint32)-2;
754   for (i = w;i < p;++i) L[i] = (in[i]&(uint32)-3)|1;
755   crypto_sort_uint32(L,p);
756   for (i = 0;i < p;++i) out[i] = (L[i]&3)-1;
757 }
758 
759 /* ----- underlying hash function */
760 
#define Hash_bytes 32

/* e.g., b = 0 means out = Hash0(in) */
/*
 * out = first Hash_bytes bytes of SHA-512(b || in), where b is a one-byte
 * prefix that distinguishes the callers' uses and in is inlen bytes.
 */
static void Hash_prefix(unsigned char *out,int b,const unsigned char *in,int inlen)
{
  unsigned char buf[inlen+1];
  unsigned char digest[64];

  buf[0] = b;
  memcpy(buf + 1,in,inlen);
  crypto_hash_sha512(digest,buf,inlen+1);
  memcpy(out,digest,Hash_bytes);
}
775 
776 /* ----- higher-level randomness */
777 
778 static uint32 urandom32(void)
779 {
780   unsigned char c[4];
781   uint32 out[4];
782 
783   randombytes(c,4);
784   out[0] = (uint32)c[0];
785   out[1] = ((uint32)c[1])<<8;
786   out[2] = ((uint32)c[2])<<16;
787   out[3] = ((uint32)c[3])<<24;
788   return out[0]+out[1]+out[2]+out[3];
789 }
790 
791 static void Short_random(small *out)
792 {
793   uint32 L[p];
794   int i;
795 
796   for (i = 0;i < p;++i) L[i] = urandom32();
797   Short_fromlist(out,L);
798 }
799 
800 #ifndef LPR
801 
802 static void Small_random(small *out)
803 {
804   int i;
805 
806   for (i = 0;i < p;++i) out[i] = (((urandom32()&0x3fffffff)*3)>>30)-1;
807 }
808 
809 #endif
810 
811 /* ----- Streamlined NTRU Prime Core */
812 
813 #ifndef LPR
814 
815 /* h,(f,ginv) = KeyGen() */
816 static void KeyGen(Fq *h,small *f,small *ginv)
817 {
818   small g[p];
819   Fq finv[p];
820 
821   for (;;) {
822     Small_random(g);
823     if (R3_recip(ginv,g) == 0) break;
824   }
825   Short_random(f);
826   Rq_recip3(finv,f); /* always works */
827   Rq_mult_small(h,finv,g);
828 }
829 
830 /* c = Encrypt(r,h) */
831 static void Encrypt(Fq *c,const small *r,const Fq *h)
832 {
833   Fq hr[p];
834 
835   Rq_mult_small(hr,h,r);
836   Round(c,hr);
837 }
838 
/* r = Decrypt(c,(f,ginv)) */
/*
 * e = 3*c*f (in Rq, then lifted to R3); ev = e*ginv in R3.
 * If ev has weight w it is returned as r.  Otherwise r is replaced, via
 * masks rather than a branch, by the fixed vector: 1 for the first w
 * coefficients, 0 for the rest.
 */
static void Decrypt(small *r,const Fq *c,const small *f,const small *ginv)
{
  Fq cf[p];
  Fq cf3[p];
  small e[p];
  small ev[p];
  int mask;
  int i;

  Rq_mult_small(cf,c,f);
  Rq_mult3(cf3,cf);
  R3_fromRq(e,cf3);
  R3_mult(ev,e,ginv);

  mask = Weightw_mask(ev); /* 0 if weight w, else -1 */
  /* mask == 0: r[i] = ev[i]; mask == -1: r[i] = 1 */
  for (i = 0;i < w;++i) r[i] = ((ev[i]^1)&~mask)^1;
  /* mask == 0: r[i] = ev[i]; mask == -1: r[i] = 0 */
  for (i = w;i < p;++i) r[i] = ev[i]&~mask;
}
858 
859 #endif
860 
861 /* ----- NTRU LPRime Core */
862 
/* NTRU LPRime core; compiled out in this build (LPR is #undef'd above). */
#ifdef LPR

/* (G,A),a = KeyGen(G); leaves G unchanged */
/* a: random short polynomial; A = Round(a*G) */
static void KeyGen(Fq *A,small *a,const Fq *G)
{
  Fq aG[p];

  Short_random(a);
  Rq_mult_small(aG,G,a);
  Round(A,aG);
}

/* B,T = Encrypt(r,(G,A),b) */
/* B = Round(b*G); T[i] = Top(b*A[i] + r[i]*q12) for the I input bits */
static void Encrypt(Fq *B,int8 *T,const int8 *r,const Fq *G,const Fq *A,const small *b)
{
  Fq bG[p];
  Fq bA[p];
  int i;

  Rq_mult_small(bG,G,b);
  Round(B,bG);
  Rq_mult_small(bA,A,b);
  for (i = 0;i < I;++i) T[i] = Top(Fq_freeze(bA[i]+r[i]*q12));
}

/* r = Decrypt((B,T),a) */
/* r[i] = 1 iff Right(T[i]) - a*B[i] + 4w + 1 is negative after reduction */
static void Decrypt(int8 *r,const Fq *B,const int8 *T,const small *a)
{
  Fq aB[p];
  int i;

  Rq_mult_small(aB,B,a);
  for (i = 0;i < I;++i)
    r[i] = -int16_negative_mask(Fq_freeze(Right(T[i])-aB[i]+4*w+1));
}

#endif
900 
901 /* ----- encoding I-bit inputs */
902 
/* LPR only; compiled out in this build (LPR is #undef'd above). */
#ifdef LPR

#define Inputs_bytes (I/8)
typedef int8 Inputs[I]; /* passed by reference */

/* Pack I one-bit values (0/1) into I/8 bytes, least significant bit first. */
static void Inputs_encode(unsigned char *s,const Inputs r)
{
  int i;
  for (i = 0;i < Inputs_bytes;++i) s[i] = 0;
  for (i = 0;i < I;++i) s[i>>3] |= r[i]<<(i&7);
}

#endif
916 
917 /* ----- Expand */
918 
/* LPR only; compiled out in this build (LPR is #undef'd above). */
#ifdef LPR

static const unsigned char aes_nonce[16] = {0};

/*
 * Fill L[0..p-1] with AES-256-CTR keystream under key k (zero nonce),
 * interpreting each 4-byte group as a little-endian word.  The conversion
 * is done in place: word i is written only after its own 4 bytes are read,
 * and later iterations read only later bytes.
 */
static void Expand(uint32 *L,const unsigned char *k)
{
  int i;
  crypto_stream_aes256ctr((unsigned char *) L,4*p,aes_nonce,k);
  for (i = 0;i < p;++i) {
    uint32 L0 = ((unsigned char *) L)[4*i];
    uint32 L1 = ((unsigned char *) L)[4*i+1];
    uint32 L2 = ((unsigned char *) L)[4*i+2];
    uint32 L3 = ((unsigned char *) L)[4*i+3];
    L[i] = L0+(L1<<8)+(L2<<16)+(L3<<24);
  }
}

#endif
937 
938 /* ----- Seeds */
939 
/* LPR only; compiled out in this build (LPR is #undef'd above). */
#ifdef LPR

#define Seeds_bytes 32

/* Fill s with Seeds_bytes random bytes. */
static void Seeds_random(unsigned char *s)
{
  randombytes(s,Seeds_bytes);
}

#endif
950 
951 /* ----- Generator, HashShort */
952 
/* LPR only; compiled out in this build (LPR is #undef'd above). */
#ifdef LPR

/* G = Generator(k) */
/* Each coefficient: an expanded word reduced mod q, centered to -q12..q12. */
static void Generator(Fq *G,const unsigned char *k)
{
  uint32 L[p];
  int i;

  Expand(L,k);
  for (i = 0;i < p;++i) G[i] = uint32_mod_uint14(L[i],q)-q12;
}

/* out = HashShort(r) */
/* Short polynomial derived from Hash5(encoded r) via Expand and Short_fromlist. */
static void HashShort(small *out,const Inputs r)
{
  unsigned char s[Inputs_bytes];
  unsigned char h[Hash_bytes];
  uint32 L[p];

  Inputs_encode(s,r);
  Hash_prefix(h,5,s,sizeof s);
  Expand(L,h);
  Short_fromlist(out,L);
}

#endif
979 
980 /* ----- NTRU LPRime Expand */
981 
/* LPR only; compiled out in this build (LPR is #undef'd above). */
#ifdef LPR

/* (S,A),a = XKeyGen() */
/* Random seed S, G = Generator(S), then KeyGen against G. */
static void XKeyGen(unsigned char *S,Fq *A,small *a)
{
  Fq G[p];

  Seeds_random(S);
  Generator(G,S);
  KeyGen(A,a,G);
}

/* B,T = XEncrypt(r,(S,A)) */
/* Regenerates G from S and derives b = HashShort(r), so encryption is deterministic in r. */
static void XEncrypt(Fq *B,int8 *T,const int8 *r,const unsigned char *S,const Fq *A)
{
  Fq G[p];
  small b[p];

  Generator(G,S);
  HashShort(b,r);
  Encrypt(B,T,r,G,A,b);
}

#define XDecrypt Decrypt

#endif
1008 
1009 /* ----- encoding small polynomials (including short polynomials) */
1010 
1011 #define Small_bytes ((p+3)/4)
1012 
1013 /* these are the only functions that rely on p mod 4 = 1 */
1014 
1015 static void Small_encode(unsigned char *s,const small *f)
1016 {
1017   small x;
1018   int i;
1019 
1020   for (i = 0;i < p/4;++i) {
1021     x = *f++ + 1;
1022     x += (*f++ + 1)<<2;
1023     x += (*f++ + 1)<<4;
1024     x += (*f++ + 1)<<6;
1025     *s++ = x;
1026   }
1027   x = *f++ + 1;
1028   *s++ = x;
1029 }
1030 
1031 static void Small_decode(small *f,const unsigned char *s)
1032 {
1033   unsigned char x;
1034   int i;
1035 
1036   for (i = 0;i < p/4;++i) {
1037     x = *s++;
1038     *f++ = ((small)(x&3))-1; x >>= 2;
1039     *f++ = ((small)(x&3))-1; x >>= 2;
1040     *f++ = ((small)(x&3))-1; x >>= 2;
1041     *f++ = ((small)(x&3))-1;
1042   }
1043   x = *s++;
1044   *f++ = ((small)(x&3))-1;
1045 }
1046 
1047 /* ----- encoding general polynomials */
1048 
1049 #ifndef LPR
1050 
1051 static void Rq_encode(unsigned char *s,const Fq *r)
1052 {
1053   uint16 R[p],M[p];
1054   int i;
1055 
1056   for (i = 0;i < p;++i) R[i] = r[i]+q12;
1057   for (i = 0;i < p;++i) M[i] = q;
1058   Encode(s,R,M,p);
1059 }
1060 
1061 static void Rq_decode(Fq *r,const unsigned char *s)
1062 {
1063   uint16 R[p],M[p];
1064   int i;
1065 
1066   for (i = 0;i < p;++i) M[i] = q;
1067   Decode(R,s,M,p);
1068   for (i = 0;i < p;++i) r[i] = ((Fq)R[i])-q12;
1069 }
1070 
1071 #endif
1072 
1073 /* ----- encoding rounded polynomials */
1074 
1075 static void Rounded_encode(unsigned char *s,const Fq *r)
1076 {
1077   uint16 R[p],M[p];
1078   int i;
1079 
1080   for (i = 0;i < p;++i) R[i] = ((r[i]+q12)*10923)>>15;
1081   for (i = 0;i < p;++i) M[i] = (q+2)/3;
1082   Encode(s,R,M,p);
1083 }
1084 
1085 static void Rounded_decode(Fq *r,const unsigned char *s)
1086 {
1087   uint16 R[p],M[p];
1088   int i;
1089 
1090   for (i = 0;i < p;++i) M[i] = (q+2)/3;
1091   Decode(R,s,M,p);
1092   for (i = 0;i < p;++i) r[i] = R[i]*3-q12;
1093 }
1094 
1095 /* ----- encoding top polynomials */
1096 
/* LPR only; compiled out in this build (LPR is #undef'd above). */
#ifdef LPR

#define Top_bytes (I/2)

/* Pack I 4-bit Top values, two per byte (even index in the low nibble). */
static void Top_encode(unsigned char *s,const int8 *T)
{
  int i;
  for (i = 0;i < Top_bytes;++i)
    s[i] = T[2*i]+(T[2*i+1]<<4);
}

/* Inverse of Top_encode. */
static void Top_decode(int8 *T,const unsigned char *s)
{
  int i;
  for (i = 0;i < Top_bytes;++i) {
    T[2*i] = s[i]&15;
    T[2*i+1] = s[i]>>4;
  }
}

#endif
1118 
1119 /* ----- Streamlined NTRU Prime Core plus encoding */
1120 
1121 #ifndef LPR
1122 
/* Streamlined NTRU Prime: the plaintext input is a short polynomial. */
typedef small Inputs[p]; /* passed by reference */
#define Inputs_random Short_random
#define Inputs_encode Small_encode
#define Inputs_bytes Small_bytes

/* Wire sizes: ciphertext is a Rounded encoding; sk holds f and ginv. */
#define Ciphertexts_bytes Rounded_bytes
#define SecretKeys_bytes (2*Small_bytes)
#define PublicKeys_bytes Rq_bytes
1131 
1132 /* pk,sk = ZKeyGen() */
1133 static void ZKeyGen(unsigned char *pk,unsigned char *sk)
1134 {
1135   Fq h[p];
1136   small f[p],v[p];
1137 
1138   KeyGen(h,f,v);
1139   Rq_encode(pk,h);
1140   Small_encode(sk,f); sk += Small_bytes;
1141   Small_encode(sk,v);
1142 }
1143 
1144 /* C = ZEncrypt(r,pk) */
1145 static void ZEncrypt(unsigned char *C,const Inputs r,const unsigned char *pk)
1146 {
1147   Fq h[p];
1148   Fq c[p];
1149   Rq_decode(h,pk);
1150   Encrypt(c,r,h);
1151   Rounded_encode(C,c);
1152 }
1153 
1154 /* r = ZDecrypt(C,sk) */
1155 static void ZDecrypt(Inputs r,const unsigned char *C,const unsigned char *sk)
1156 {
1157   small f[p],v[p];
1158   Fq c[p];
1159 
1160   Small_decode(f,sk); sk += Small_bytes;
1161   Small_decode(v,sk);
1162   Rounded_decode(c,C);
1163   Decrypt(r,c,f,v);
1164 }
1165 
1166 #endif
1167 
1168 /* ----- NTRU LPRime Expand plus encoding */
1169 
/* NTRU LPRime wire format; compiled out in this build (LPR is #undef'd above). */
#ifdef LPR

#define Ciphertexts_bytes (Rounded_bytes+Top_bytes)
#define SecretKeys_bytes Small_bytes
#define PublicKeys_bytes (Seeds_bytes+Rounded_bytes)

/* I random bits, one per Inputs entry (LSB-first within each random byte). */
static void Inputs_random(Inputs r)
{
  unsigned char s[Inputs_bytes];
  int i;

  randombytes(s,sizeof s);
  for (i = 0;i < I;++i) r[i] = 1&(s[i>>3]>>(i&7));
}

/* pk,sk = ZKeyGen() */
/* pk = seed S || Rounded(A); sk = Small(a) */
static void ZKeyGen(unsigned char *pk,unsigned char *sk)
{
  Fq A[p];
  small a[p];

  XKeyGen(pk,A,a); pk += Seeds_bytes;
  Rounded_encode(pk,A);
  Small_encode(sk,a);
}

/* c = ZEncrypt(r,pk) */
/* c = Rounded(B) || Top(T) */
static void ZEncrypt(unsigned char *c,const Inputs r,const unsigned char *pk)
{
  Fq A[p];
  Fq B[p];
  int8 T[I];

  Rounded_decode(A,pk+Seeds_bytes);
  XEncrypt(B,T,r,pk,A);
  Rounded_encode(c,B); c += Rounded_bytes;
  Top_encode(c,T);
}

/* r = ZDecrypt(C,sk) */
static void ZDecrypt(Inputs r,const unsigned char *c,const unsigned char *sk)
{
  small a[p];
  Fq B[p];
  int8 T[I];

  Small_decode(a,sk);
  Rounded_decode(B,c);
  Top_decode(T,c+Rounded_bytes);
  XDecrypt(r,B,T,a);
}

#endif
1223 
1224 /* ----- confirmation hash */
1225 
1226 #define Confirm_bytes 32
1227 
1228 /* h = HashConfirm(r,pk,cache); cache is Hash4(pk) */
1229 static void HashConfirm(unsigned char *h,const unsigned char *r,const unsigned char *pk,const unsigned char *cache)
1230 {
1231 #ifndef LPR
1232   unsigned char x[Hash_bytes*2];
1233   int i;
1234 
1235   Hash_prefix(x,3,r,Inputs_bytes);
1236   for (i = 0;i < Hash_bytes;++i) x[Hash_bytes+i] = cache[i];
1237 #else
1238   unsigned char x[Inputs_bytes+Hash_bytes];
1239   int i;
1240 
1241   for (i = 0;i < Inputs_bytes;++i) x[i] = r[i];
1242   for (i = 0;i < Hash_bytes;++i) x[Inputs_bytes+i] = cache[i];
1243 #endif
1244   Hash_prefix(h,2,x,sizeof x);
1245 }
1246 
1247 /* ----- session-key hash */
1248 
1249 /* k = HashSession(b,y,z) */
1250 static void HashSession(unsigned char *k,int b,const unsigned char *y,const unsigned char *z)
1251 {
1252 #ifndef LPR
1253   unsigned char x[Hash_bytes+Ciphertexts_bytes+Confirm_bytes];
1254   int i;
1255 
1256   Hash_prefix(x,3,y,Inputs_bytes);
1257   for (i = 0;i < Ciphertexts_bytes+Confirm_bytes;++i) x[Hash_bytes+i] = z[i];
1258 #else
1259   unsigned char x[Inputs_bytes+Ciphertexts_bytes+Confirm_bytes];
1260   int i;
1261 
1262   for (i = 0;i < Inputs_bytes;++i) x[i] = y[i];
1263   for (i = 0;i < Ciphertexts_bytes+Confirm_bytes;++i) x[Inputs_bytes+i] = z[i];
1264 #endif
1265   Hash_prefix(k,b,x,sizeof x);
1266 }
1267 
1268 /* ----- Streamlined NTRU Prime and NTRU LPRime */
1269 
1270 /* pk,sk = KEM_KeyGen() */
1271 static void KEM_KeyGen(unsigned char *pk,unsigned char *sk)
1272 {
1273   int i;
1274 
1275   ZKeyGen(pk,sk); sk += SecretKeys_bytes;
1276   for (i = 0;i < PublicKeys_bytes;++i) *sk++ = pk[i];
1277   randombytes(sk,Inputs_bytes); sk += Inputs_bytes;
1278   Hash_prefix(sk,4,pk,PublicKeys_bytes);
1279 }
1280 
1281 /* c,r_enc = Hide(r,pk,cache); cache is Hash4(pk) */
1282 static void Hide(unsigned char *c,unsigned char *r_enc,const Inputs r,const unsigned char *pk,const unsigned char *cache)
1283 {
1284   Inputs_encode(r_enc,r);
1285   ZEncrypt(c,r,pk); c += Ciphertexts_bytes;
1286   HashConfirm(c,r_enc,pk,cache);
1287 }
1288 
1289 /* c,k = Encap(pk) */
1290 static void Encap(unsigned char *c,unsigned char *k,const unsigned char *pk)
1291 {
1292   Inputs r;
1293   unsigned char r_enc[Inputs_bytes];
1294   unsigned char cache[Hash_bytes];
1295 
1296   Hash_prefix(cache,4,pk,PublicKeys_bytes);
1297   Inputs_random(r);
1298   Hide(c,r_enc,r,pk,cache);
1299   HashSession(k,1,r_enc,c);
1300 }
1301 
1302 /* 0 if matching ciphertext+confirm, else -1 */
1303 static int Ciphertexts_diff_mask(const unsigned char *c,const unsigned char *c2)
1304 {
1305   uint16 differentbits = 0;
1306   int len = Ciphertexts_bytes+Confirm_bytes;
1307 
1308   while (len-- > 0) differentbits |= (*c++)^(*c2++);
1309   return (1&((differentbits-1)>>8))-1;
1310 }
1311 
/* k = Decap(c,sk) */
/*
 * Decrypt c, re-encrypt the result with Hide, and compare with c using
 * Ciphertexts_diff_mask (no early exit).
 * Match (mask 0): k = HashSession(1, r_enc, c).
 * Mismatch (mask -1): r_enc is replaced by rho from sk, and
 * k = HashSession(0, rho, c).
 * Both selections use masks rather than branches.
 * sk layout read here: secret key || pk || rho || cache (Hash4(pk)).
 */
static void Decap(unsigned char *k,const unsigned char *c,const unsigned char *sk)
{
  const unsigned char *pk = sk + SecretKeys_bytes;
  const unsigned char *rho = pk + PublicKeys_bytes;
  const unsigned char *cache = rho + Inputs_bytes;
  Inputs r;
  unsigned char r_enc[Inputs_bytes];
  unsigned char cnew[Ciphertexts_bytes+Confirm_bytes];
  int mask;
  int i;

  ZDecrypt(r,c,sk);
  Hide(cnew,r_enc,r,pk,cache);
  mask = Ciphertexts_diff_mask(c,cnew);
  /* mask == -1: r_enc becomes rho; mask == 0: r_enc unchanged */
  for (i = 0;i < Inputs_bytes;++i) r_enc[i] ^= mask&(r_enc[i]^rho[i]);
  /* hash prefix 1 on success, 0 on rejection */
  HashSession(k,1+mask,r_enc,c);
}
1330 
1331 /* ----- crypto_kem API */
1332 
1333 
/* Generate a key pair into pk and sk.  Always returns 0. */
int crypto_kem_sntrup761_keypair(unsigned char *pk,unsigned char *sk)
{
  const int status = 0;

  KEM_KeyGen(pk,sk);
  return status;
}

/* Encapsulate to pk: writes ciphertext c and shared key k.  Always returns 0. */
int crypto_kem_sntrup761_enc(unsigned char *c,unsigned char *k,const unsigned char *pk)
{
  const int status = 0;

  Encap(c,k,pk);
  return status;
}

/*
 * Decapsulate c with sk into shared key k.  Always returns 0, including
 * when the re-encryption check in Decap fails.
 */
int crypto_kem_sntrup761_dec(unsigned char *k,const unsigned char *c,const unsigned char *sk)
{
  const int status = 0;

  Decap(k,c,sk);
  return status;
}
1351 
1352