xref: /netbsd-src/crypto/external/bsd/openssh/dist/sntrup761.c (revision 2718af68c3efc72c9769069b5c7f9ed36f6b9def)
1 /*	$NetBSD: sntrup761.c,v 1.2 2021/09/03 10:30:33 christos Exp $	*/
2 /*  $OpenBSD: sntrup761.c,v 1.5 2021/01/08 02:33:13 dtucker Exp $ */
3 
4 /*
5  * Public Domain, Authors:
6  * - Daniel J. Bernstein
7  * - Chitchanok Chuengsatiansup
8  * - Tanja Lange
9  * - Christine van Vredendaal
10  */
11 #include "includes.h"
12 __RCSID("$NetBSD: sntrup761.c,v 1.2 2021/09/03 10:30:33 christos Exp $");
13 
14 #include <string.h>
15 #include "crypto_api.h"
16 
17 #define int8 crypto_int8
18 #define uint8 crypto_uint8
19 #define int16 crypto_int16
20 #define uint16 crypto_uint16
21 #define int32 crypto_int32
22 #define uint32 crypto_uint32
23 #define int64 crypto_int64
24 #define uint64 crypto_uint64
25 
26 /* from supercop-20201130/crypto_sort/int32/portable4/int32_minmax.inc */
/*
 * Branch-free compare-exchange of two int32 lvalues: afterwards a holds
 * the smaller value and b the larger one.  The difference is formed in
 * 64 bits; its sign (corrected by the xor term) becomes an all-ones or
 * all-zero mask that conditionally swaps the operands.  Both arguments
 * are evaluated several times and must be side-effect free.
 */
#define int32_MINMAX(a,b) \
do { \
  int64_t minmax_diff = (int64_t)(b) - (int64_t)(a); \
  int64_t minmax_xor = (int64_t)(b) ^ (int64_t)(a); \
  minmax_diff ^= minmax_xor & (minmax_diff ^ (b)); \
  minmax_diff >>= 31; \
  minmax_diff &= minmax_xor; \
  (a) ^= minmax_diff; \
  (b) ^= minmax_diff; \
} while(0)
37 
38 /* from supercop-20201130/crypto_sort/int32/portable4/sort.c */
39 
40 
/*
 * Sort the n int32 values stored at array into ascending order, in place.
 *
 * The routine consists solely of int32_MINMAX compare-exchange steps.
 * Every loop bound and array index below is computed from n alone and
 * never from the values being sorted, so the sequence of memory accesses
 * is fixed for a given n.  Inputs with fewer than two elements are left
 * untouched.
 */
static void crypto_sort_int32(void *array,long long n)
{
  long long top,p,q,r,i,j;
  int32 *x = array;

  if (n < 2) return;
  /* top: smallest power of two with top >= n - top */
  top = 1;
  while (top < n - top) top += top;

  for (p = top;p >= 1;p >>= 1) {
    /* stride-p pass: whole 2p-wide groups first ... */
    i = 0;
    while (i + 2 * p <= n) {
      for (j = i;j < i + p;++j)
        int32_MINMAX(x[j],x[j+p]);
      i += 2 * p;
    }
    /* ... then whatever part of the last group still has a partner */
    for (j = i;j < n - p;++j)
      int32_MINMAX(x[j],x[j+p]);

    /*
     * For each larger stride q (top down to 2p), x[j+p] is compared
     * against x[j+r] for r = q, q/2, ..., 2p.  i and j carry over
     * between values of q so positions already handled are not redone.
     */
    i = 0;
    j = 0;
    for (q = top;q > p;q >>= 1) {
      /* finish a group that the previous q left partially processed */
      if (j != i) for (;;) {
        if (j == n - q) goto done;
        int32 a = x[j + p];
        for (r = q;r > p;r >>= 1)
          int32_MINMAX(a,x[j + r]);
        x[j + p] = a;
        ++j;
        if (j == i + p) {
          i += 2 * p;
          break;
        }
      }
      /* whole groups of p positions */
      while (i + p <= n - q) {
        for (j = i;j < i + p;++j) {
          int32 a = x[j + p];
          for (r = q;r > p;r >>= 1)
            int32_MINMAX(a,x[j+r]);
          x[j + p] = a;
        }
        i += 2 * p;
      }
      /* now i + p > n - q */
      j = i;
      while (j < n - q) {
        int32 a = x[j + p];
        for (r = q;r > p;r >>= 1)
          int32_MINMAX(a,x[j+r]);
        x[j + p] = a;
        ++j;
      }

      done: ;
    }
  }
}
98 
99 /* from supercop-20201130/crypto_sort/uint32/useint32/sort.c */
100 
101 /* can save time by vectorizing xor loops */
102 /* can save time by integrating xor loops with int32_sort */
103 
104 static void crypto_sort_uint32(void *array,long long n)
105 {
106   crypto_uint32 *x = array;
107   long long j;
108   for (j = 0;j < n;++j) x[j] ^= 0x80000000;
109   crypto_sort_int32(array,n);
110   for (j = 0;j < n;++j) x[j] ^= 0x80000000;
111 }
112 
113 /* from supercop-20201130/crypto_kem/sntrup761/ref/uint32.c */
114 
115 /*
116 CPU division instruction typically takes time depending on x.
117 This software is designed to take time independent of x.
118 Time still varies depending on m; user must ensure that m is constant.
119 Time also varies on CPUs where multiplication is variable-time.
120 There could be more CPU issues.
121 There could also be compiler issues.
122 */
123 
/*
 * Store the quotient of x by m in *q and the remainder in *r.
 *
 * The only division performed is 0x80000000 / m, whose operands do not
 * involve x.  The quotient is then built from two multiply-and-shift
 * estimates followed by one masked correction step, so no branch or
 * division depends on x.
 */
static void uint32_divmod_uint14(uint32 *q,uint16 *r,uint32 x,uint16 m)
{
  uint32 v = 0x80000000;
  uint32 qpart;
  uint32 mask;

  v /= m;

  /* caller guarantees m > 0 */
  /* caller guarantees m < 16384 */
  /* vm <= 2^31 <= vm+m-1 */
  /* xvm <= 2^31 x <= xvm+x(m-1) */

  *q = 0;

  /* first quotient estimate */
  qpart = (x*(uint64)v)>>31;
  /* 2^31 qpart <= xv <= 2^31 qpart + 2^31-1 */
  /* 2^31 qpart m <= xvm <= 2^31 qpart m + (2^31-1)m */
  /* 2^31 qpart m <= 2^31 x <= 2^31 qpart m + (2^31-1)m + x(m-1) */
  /* 0 <= 2^31 newx <= (2^31-1)m + x(m-1) */
  /* 0 <= newx <= (1-1/2^31)m + x(m-1)/2^31 */
  /* 0 <= newx <= (1-1/2^31)(2^14-1) + (2^32-1)((2^14-1)-1)/2^31 */

  x -= qpart*m; *q += qpart;
  /* x <= 49146 */

  /* second estimate on the reduced x */
  qpart = (x*(uint64)v)>>31;
  /* 0 <= newx <= (1-1/2^31)m + x(m-1)/2^31 */
  /* 0 <= newx <= m + 49146(2^14-1)/2^31 */
  /* 0 <= newx <= m + 0.4 */
  /* 0 <= newx <= m */

  x -= qpart*m; *q += qpart;
  /* x <= m */

  /* subtract m once more; if that wrapped below zero, undo it via the mask */
  x -= m; *q += 1;
  mask = -(x>>31);
  x += mask&(uint32)m; *q += mask;
  /* x < m */

  *r = x;
}
166 
167 
168 static uint16 uint32_mod_uint14(uint32 x,uint16 m)
169 {
170   uint32 q;
171   uint16 r;
172   uint32_divmod_uint14(&q,&r,x,m);
173   return r;
174 }
175 
176 /* from supercop-20201130/crypto_kem/sntrup761/ref/int32.c */
177 
/*
 * Signed counterpart of uint32_divmod_uint14.
 *
 * x is biased by 2^31 so it can be divided as an unsigned value; the
 * quotient and remainder of 2^31 itself are then subtracted again.  When
 * the remainder difference is negative (top bit of the uint16 set), m is
 * added back and the quotient is decreased by one, using a mask rather
 * than a branch.  The remainder stored in *r is therefore non-negative.
 */
static void int32_divmod_uint14(int32 *q,uint16 *r,int32 x,uint16 m)
{
  uint32 uq,uq2;
  uint16 ur,ur2;
  uint32 mask;

  uint32_divmod_uint14(&uq,&ur,0x80000000+(uint32)x,m);
  uint32_divmod_uint14(&uq2,&ur2,0x80000000,m);
  ur -= ur2; uq -= uq2;
  mask = -(uint32)(ur>>15); /* all ones when ur wrapped below zero */
  ur += mask&m; uq += mask;
  *r = ur; *q = uq;
}
191 
192 
193 static uint16 int32_mod_uint14(int32 x,uint16 m)
194 {
195   int32 q;
196   uint16 r;
197   int32_divmod_uint14(&q,&r,x,m);
198   return r;
199 }
200 
201 /* from supercop-20201130/crypto_kem/sntrup761/ref/paramsmenu.h */
/*
 * Build configuration.  This file selects the 761 parameter set and the
 * Streamlined NTRU Prime variant: LPR is left undefined, so every
 * "#ifdef LPR" section further down is compiled out.
 */
/* pick one of these three: */
#define SIZE761
#undef SIZE653
#undef SIZE857

/* pick one of these two: */
#define SNTRUP /* Streamlined NTRU Prime */
#undef LPR /* NTRU LPRime */
210 
211 /* from supercop-20201130/crypto_kem/sntrup761/ref/params.h */
#ifndef params_H
#define params_H

/* menu of parameter choices: see the SIZE and SNTRUP/LPR selections above */


/* what the menu means: */
/*
 * p             number of coefficients in every polynomial array
 * q             modulus used by Fq_freeze (coefficients kept in -q12...q12)
 * w             weight checked by Weightw_mask / produced by Short_fromlist
 * Rounded_bytes size of an encoded rounded polynomial
 * Rq_bytes      size of an encoded polynomial mod q (non-LPR only)
 * tau0..tau3    constants used by Top() and Right() (LPR only)
 */

#if defined(SIZE761)
#define p 761
#define q 4591
#define Rounded_bytes 1007
#ifndef LPR
#define Rq_bytes 1158
#define w 286
#else
#define w 250
#define tau0 2156
#define tau1 114
#define tau2 2007
#define tau3 287
#endif

#elif defined(SIZE653)
#define p 653
#define q 4621
#define Rounded_bytes 865
#ifndef LPR
#define Rq_bytes 994
#define w 288
#else
#define w 252
#define tau0 2175
#define tau1 113
#define tau2 2031
#define tau3 290
#endif

#elif defined(SIZE857)
#define p 857
#define q 5167
#define Rounded_bytes 1152
#ifndef LPR
#define Rq_bytes 1322
#define w 322
#else
#define w 281
#define tau0 2433
#define tau1 101
#define tau2 2265
#define tau3 324
#endif

#else
#error "no parameter set defined"
#endif

/* I: number of input bits handled by the LPR variant */
#ifdef LPR
#define I 256
#endif

#endif
274 
275 /* from supercop-20201130/crypto_kem/sntrup761/ref/Decode.h */
276 #ifndef Decode_H
277 #define Decode_H
278 
279 
280 /* Decode(R,s,M,len) */
281 /* assumes 0 < M[i] < 16384 */
282 /* produces 0 <= R[i] < M[i] */
283 
284 #endif
285 
286 /* from supercop-20201130/crypto_kem/sntrup761/ref/Decode.c */
287 
/*
 * Decode(out,S,M,len): read bytes from S and write len values to out,
 * reducing out[i] modulo M[i].
 *
 * Pairs (M[i],M[i+1]) are combined into a product m.  Depending on the
 * size of m, two, one or zero bytes are taken from S as the low part of
 * the combined value, and a reduced modulus M2 is recorded.  The reduced
 * list, half as long (rounded up), is decoded recursively from the
 * remaining bytes; each pair is then rebuilt from its low part and the
 * recursive result and split with a division by M[i].
 */
static void Decode(uint16 *out,const unsigned char *S,const uint16 *M,long long len)
{
  if (len == 1) {
    /* base case: zero, one or two bytes depending on the size of M[0] */
    if (M[0] == 1)
      *out = 0;
    else if (M[0] <= 256)
      *out = uint32_mod_uint14(S[0],M[0]);
    else
      *out = uint32_mod_uint14(S[0]+(((uint16)S[1])<<8),M[0]);
  }
  if (len > 1) {
    uint16 R2[(len+1)/2];    /* values decoded by the recursive call */
    uint16 M2[(len+1)/2];    /* moduli for the recursive call */
    uint16 bottomr[len/2];   /* low part of each pair, read from S */
    uint32 bottomt[len/2];   /* weight of the recursive part: 1, 256 or 65536 */
    long long i;
    for (i = 0;i < len-1;i += 2) {
      uint32 m = M[i]*(uint32) M[i+1];
      if (m > 256*16383) {
        bottomt[i/2] = 256*256;
        bottomr[i/2] = S[0]+256*S[1];
        S += 2;
        M2[i/2] = (((m+255)>>8)+255)>>8;
      } else if (m >= 16384) {
        bottomt[i/2] = 256;
        bottomr[i/2] = S[0];
        S += 1;
        M2[i/2] = (m+255)>>8;
      } else {
        bottomt[i/2] = 1;
        bottomr[i/2] = 0;
        M2[i/2] = m;
      }
    }
    /* odd length: the last modulus is passed down unchanged */
    if (i < len)
      M2[i/2] = M[i];
    Decode(R2,S,M2,(len+1)/2);
    for (i = 0;i < len-1;i += 2) {
      uint32 r = bottomr[i/2];
      uint32 r1;
      uint16 r0;
      r += bottomt[i/2]*R2[i/2];
      uint32_divmod_uint14(&r1,&r0,r,M[i]);
      r1 = uint32_mod_uint14(r1,M[i+1]); /* only needed for invalid inputs */
      *out++ = r0;
      *out++ = r1;
    }
    if (i < len)
      *out++ = R2[i/2];
  }
}
339 
340 /* from supercop-20201130/crypto_kem/sntrup761/ref/Encode.h */
341 #ifndef Encode_H
342 #define Encode_H
343 
344 
345 /* Encode(s,R,M,len) */
346 /* assumes 0 <= R[i] < M[i] < 16384 */
347 
348 #endif
349 
350 /* from supercop-20201130/crypto_kem/sntrup761/ref/Encode.c */
351 
352 /* 0 <= R[i] < M[i] < 16384 */
353 static void Encode(unsigned char *out,const uint16 *R,const uint16 *M,long long len)
354 {
355   if (len == 1) {
356     uint16 r = R[0];
357     uint16 m = M[0];
358     while (m > 1) {
359       *out++ = r;
360       r >>= 8;
361       m = (m+255)>>8;
362     }
363   }
364   if (len > 1) {
365     uint16 R2[(len+1)/2];
366     uint16 M2[(len+1)/2];
367     long long i;
368     for (i = 0;i < len-1;i += 2) {
369       uint32 m0 = M[i];
370       uint32 r = R[i]+R[i+1]*m0;
371       uint32 m = M[i+1]*m0;
372       while (m >= 16384) {
373         *out++ = r;
374         r >>= 8;
375         m = (m+255)>>8;
376       }
377       R2[i/2] = r;
378       M2[i/2] = m;
379     }
380     if (i < len) {
381       R2[i/2] = R[i];
382       M2[i/2] = M[i];
383     }
384     Encode(out,R2,M2,(len+1)/2);
385   }
386 }
387 
388 /* from supercop-20201130/crypto_kem/sntrup761/ref/kem.c */
389 
390 #ifdef LPR
391 #endif
392 
393 
394 /* ----- masks */
395 
396 #ifndef LPR
397 
398 /* return -1 if x!=0; else return 0 */
399 static int int16_nonzero_mask(int16 x)
400 {
401   uint16 u = x; /* 0, else 1...65535 */
402   uint32 v = u; /* 0, else 1...65535 */
403   v = -v; /* 0, else 2^32-65535...2^32-1 */
404   v >>= 31; /* 0, else 1 */
405   return -v; /* 0, else -1 */
406 }
407 
408 #endif
409 
410 /* return -1 if x<0; otherwise return 0 */
411 static int int16_negative_mask(int16 x)
412 {
413   uint16 u = x;
414   u >>= 15;
415   return -(int) u;
416   /* alternative with gcc -fwrapv: */
417   /* x>>15 compiles to CPU's arithmetic right shift */
418 }
419 
420 /* ----- arithmetic mod 3 */
421 
/* Element type for coefficients that are reduced mod 3. */
typedef int8 small;

/* F3 is always represented as -1,0,1 */
/* so ZZ_fromF3 is a no-op */
426 
427 /* x must not be close to top int16 */
428 static small F3_freeze(int16 x)
429 {
430   return int32_mod_uint14(x+1,3)-1;
431 }
432 
433 /* ----- arithmetic mod q */
434 
/* q12: largest magnitude of a reduced residue mod q */
#define q12 ((q-1)/2)
/* Element type for coefficients that are reduced mod q. */
typedef int16 Fq;
/* always represented as -q12...q12 */
/* so ZZ_fromFq is a no-op */
439 
440 /* x must not be close to top int32 */
441 static Fq Fq_freeze(int32 x)
442 {
443   return int32_mod_uint14(x+q12,q)-q12;
444 }
445 
446 #ifndef LPR
447 
448 static Fq Fq_recip(Fq a1)
449 {
450   int i = 1;
451   Fq ai = a1;
452 
453   while (i < q-2) {
454     ai = Fq_freeze(a1*(int32)ai);
455     i += 1;
456   }
457   return ai;
458 }
459 
460 #endif
461 
462 /* ----- Top and Right */
463 
464 #ifdef LPR
465 #define tau 16
466 
467 static int8 Top(Fq C)
468 {
469   return (tau1*(int32)(C+tau0)+16384)>>15;
470 }
471 
472 static Fq Right(int8 T)
473 {
474   return Fq_freeze(tau3*(int32)T-tau2);
475 }
476 #endif
477 
478 /* ----- small polynomials */
479 
480 #ifndef LPR
481 
482 /* 0 if Weightw_is(r), else -1 */
483 static int Weightw_mask(small *r)
484 {
485   int weight = 0;
486   int i;
487 
488   for (i = 0;i < p;++i) weight += r[i]&1;
489   return int16_nonzero_mask(weight-w);
490 }
491 
492 /* R3_fromR(R_fromRq(r)) */
493 static void R3_fromRq(small *out,const Fq *r)
494 {
495   int i;
496   for (i = 0;i < p;++i) out[i] = F3_freeze(r[i]);
497 }
498 
499 /* h = f*g in the ring R3 */
500 static void R3_mult(small *h,const small *f,const small *g)
501 {
502   small fg[p+p-1];
503   small result;
504   int i,j;
505 
506   for (i = 0;i < p;++i) {
507     result = 0;
508     for (j = 0;j <= i;++j) result = F3_freeze(result+f[j]*g[i-j]);
509     fg[i] = result;
510   }
511   for (i = p;i < p+p-1;++i) {
512     result = 0;
513     for (j = i-p+1;j < p;++j) result = F3_freeze(result+f[j]*g[i-j]);
514     fg[i] = result;
515   }
516 
517   for (i = p+p-2;i >= p;--i) {
518     fg[i-p] = F3_freeze(fg[i-p]+fg[i]);
519     fg[i-p+1] = F3_freeze(fg[i-p+1]+fg[i]);
520   }
521 
522   for (i = 0;i < p;++i) h[i] = fg[i];
523 }
524 
/*
 * Compute the reciprocal of in in R3 and store it in out.
 * returns 0 if recip succeeded; else -1
 *
 * The loop always runs 2*p-1 iterations.  The conditional swap is applied
 * through the mask `swap` instead of a branch, so the control flow does
 * not depend on the coefficients.
 */
static int R3_recip(small *out,const small *in)
{
  small f[p+1],g[p+1],v[p+1],r[p+1];
  int i,loop,delta;
  int sign,swap,t;

  for (i = 0;i < p+1;++i) v[i] = 0;
  for (i = 0;i < p+1;++i) r[i] = 0;
  r[0] = 1;
  /* f: 1 at index 0, -1 at indices p-1 and p, zero elsewhere */
  for (i = 0;i < p;++i) f[i] = 0;
  f[0] = 1; f[p-1] = f[p] = -1;
  /* g: the input with its coefficient order reversed */
  for (i = 0;i < p;++i) g[p-1-i] = in[i];
  g[p] = 0;

  delta = 1;

  for (loop = 0;loop < 2*p-1;++loop) {
    /* shift v up by one position */
    for (i = p;i > 0;--i) v[i] = v[i-1];
    v[0] = 0;

    sign = -g[0]*f[0];
    /* swap is -1 when delta > 0 and g[0] != 0, else 0 */
    swap = int16_negative_mask(-delta) & int16_nonzero_mask(g[0]);
    delta ^= swap&(delta^-delta); /* negate delta when swapping */
    delta += 1;

    /* masked exchange of (f,g) and (v,r) */
    for (i = 0;i < p+1;++i) {
      t = swap&(f[i]^g[i]); f[i] ^= t; g[i] ^= t;
      t = swap&(v[i]^r[i]); v[i] ^= t; r[i] ^= t;
    }

    /* add sign*f to g (and sign*v to r), reducing mod 3 */
    for (i = 0;i < p+1;++i) g[i] = F3_freeze(g[i]+sign*f[i]);
    for (i = 0;i < p+1;++i) r[i] = F3_freeze(r[i]+sign*v[i]);

    /* shift g down by one position */
    for (i = 0;i < p;++i) g[i] = g[i+1];
    g[p] = 0;
  }

  sign = f[0];
  for (i = 0;i < p;++i) out[i] = sign*v[p-1-i];

  return int16_nonzero_mask(delta);
}
568 
569 #endif
570 
571 /* ----- polynomials mod q */
572 
573 /* h = f*g in the ring Rq */
574 static void Rq_mult_small(Fq *h,const Fq *f,const small *g)
575 {
576   Fq fg[p+p-1];
577   Fq result;
578   int i,j;
579 
580   for (i = 0;i < p;++i) {
581     result = 0;
582     for (j = 0;j <= i;++j) result = Fq_freeze(result+f[j]*(int32)g[i-j]);
583     fg[i] = result;
584   }
585   for (i = p;i < p+p-1;++i) {
586     result = 0;
587     for (j = i-p+1;j < p;++j) result = Fq_freeze(result+f[j]*(int32)g[i-j]);
588     fg[i] = result;
589   }
590 
591   for (i = p+p-2;i >= p;--i) {
592     fg[i-p] = Fq_freeze(fg[i-p]+fg[i]);
593     fg[i-p+1] = Fq_freeze(fg[i-p+1]+fg[i]);
594   }
595 
596   for (i = 0;i < p;++i) h[i] = fg[i];
597 }
598 
599 #ifndef LPR
600 
601 /* h = 3f in Rq */
602 static void Rq_mult3(Fq *h,const Fq *f)
603 {
604   int i;
605 
606   for (i = 0;i < p;++i) h[i] = Fq_freeze(3*f[i]);
607 }
608 
/*
 * out = 1/(3*in) in Rq
 * returns 0 if recip succeeded; else -1
 *
 * Same loop structure as R3_recip, with arithmetic mod q: 2*p-1
 * iterations, a masked swap instead of a branch, and a cross
 * multiplication by f[0] and g[0] in the elimination step.  The factor
 * 1/3 enters through the initial value of r[0]; the result is rescaled
 * by 1/f[0] at the end.
 */
static int Rq_recip3(Fq *out,const small *in)
{
  Fq f[p+1],g[p+1],v[p+1],r[p+1];
  int i,loop,delta;
  int swap,t;
  int32 f0,g0;
  Fq scale;

  for (i = 0;i < p+1;++i) v[i] = 0;
  for (i = 0;i < p+1;++i) r[i] = 0;
  r[0] = Fq_recip(3);
  /* f: 1 at index 0, -1 at indices p-1 and p, zero elsewhere */
  for (i = 0;i < p;++i) f[i] = 0;
  f[0] = 1; f[p-1] = f[p] = -1;
  /* g: the input with its coefficient order reversed */
  for (i = 0;i < p;++i) g[p-1-i] = in[i];
  g[p] = 0;

  delta = 1;

  for (loop = 0;loop < 2*p-1;++loop) {
    /* shift v up by one position */
    for (i = p;i > 0;--i) v[i] = v[i-1];
    v[0] = 0;

    /* swap is -1 when delta > 0 and g[0] != 0, else 0 */
    swap = int16_negative_mask(-delta) & int16_nonzero_mask(g[0]);
    delta ^= swap&(delta^-delta); /* negate delta when swapping */
    delta += 1;

    /* masked exchange of (f,g) and (v,r) */
    for (i = 0;i < p+1;++i) {
      t = swap&(f[i]^g[i]); f[i] ^= t; g[i] ^= t;
      t = swap&(v[i]^r[i]); v[i] ^= t; r[i] ^= t;
    }

    /* g = f0*g - g0*f (which zeroes g[0]); the same combination for r and v */
    f0 = f[0];
    g0 = g[0];
    for (i = 0;i < p+1;++i) g[i] = Fq_freeze(f0*g[i]-g0*f[i]);
    for (i = 0;i < p+1;++i) r[i] = Fq_freeze(f0*r[i]-g0*v[i]);

    /* shift g down by one position */
    for (i = 0;i < p;++i) g[i] = g[i+1];
    g[p] = 0;
  }

  scale = Fq_recip(f[0]);
  for (i = 0;i < p;++i) out[i] = Fq_freeze(scale*(int32)v[p-1-i]);

  return int16_nonzero_mask(delta);
}
656 
657 #endif
658 
659 /* ----- rounded polynomials mod q */
660 
661 static void Round(Fq *out,const Fq *a)
662 {
663   int i;
664   for (i = 0;i < p;++i) out[i] = a[i]-F3_freeze(a[i]);
665 }
666 
667 /* ----- sorting to generate short polynomial */
668 
669 static void Short_fromlist(small *out,const uint32 *in)
670 {
671   uint32 L[p];
672   int i;
673 
674   for (i = 0;i < w;++i) L[i] = in[i]&(uint32)-2;
675   for (i = w;i < p;++i) L[i] = (in[i]&(uint32)-3)|1;
676   crypto_sort_uint32(L,p);
677   for (i = 0;i < p;++i) out[i] = (L[i]&3)-1;
678 }
679 
680 /* ----- underlying hash function */
681 
682 #define Hash_bytes 32
683 
/*
 * out = first 32 bytes of SHA-512(b || in[0..inlen-1]).
 * e.g., b = 0 means out = Hash0(in)
 */
static void Hash_prefix(unsigned char *out,int b,const unsigned char *in,int inlen)
{
  unsigned char msg[inlen+1];
  unsigned char digest[64];

  msg[0] = b;
  memcpy(msg+1,in,inlen);
  crypto_hash_sha512(digest,msg,inlen+1);
  memcpy(out,digest,32);
}
696 
697 /* ----- higher-level randomness */
698 
699 static uint32 urandom32(void)
700 {
701   unsigned char c[4];
702   uint32 out[4];
703 
704   randombytes(c,4);
705   out[0] = (uint32)c[0];
706   out[1] = ((uint32)c[1])<<8;
707   out[2] = ((uint32)c[2])<<16;
708   out[3] = ((uint32)c[3])<<24;
709   return out[0]+out[1]+out[2]+out[3];
710 }
711 
712 static void Short_random(small *out)
713 {
714   uint32 L[p];
715   int i;
716 
717   for (i = 0;i < p;++i) L[i] = urandom32();
718   Short_fromlist(out,L);
719 }
720 
721 #ifndef LPR
722 
723 static void Small_random(small *out)
724 {
725   int i;
726 
727   for (i = 0;i < p;++i) out[i] = (((urandom32()&0x3fffffff)*3)>>30)-1;
728 }
729 
730 #endif
731 
732 /* ----- Streamlined NTRU Prime Core */
733 
734 #ifndef LPR
735 
736 /* h,(f,ginv) = KeyGen() */
737 static void KeyGen(Fq *h,small *f,small *ginv)
738 {
739   small g[p];
740   Fq finv[p];
741 
742   for (;;) {
743     Small_random(g);
744     if (R3_recip(ginv,g) == 0) break;
745   }
746   Short_random(f);
747   Rq_recip3(finv,f); /* always works */
748   Rq_mult_small(h,finv,g);
749 }
750 
751 /* c = Encrypt(r,h) */
752 static void Encrypt(Fq *c,const small *r,const Fq *h)
753 {
754   Fq hr[p];
755 
756   Rq_mult_small(hr,h,r);
757   Round(c,hr);
758 }
759 
760 /* r = Decrypt(c,(f,ginv)) */
761 static void Decrypt(small *r,const Fq *c,const small *f,const small *ginv)
762 {
763   Fq cf[p];
764   Fq cf3[p];
765   small e[p];
766   small ev[p];
767   int mask;
768   int i;
769 
770   Rq_mult_small(cf,c,f);
771   Rq_mult3(cf3,cf);
772   R3_fromRq(e,cf3);
773   R3_mult(ev,e,ginv);
774 
775   mask = Weightw_mask(ev); /* 0 if weight w, else -1 */
776   for (i = 0;i < w;++i) r[i] = ((ev[i]^1)&~mask)^1;
777   for (i = w;i < p;++i) r[i] = ev[i]&~mask;
778 }
779 
780 #endif
781 
782 /* ----- NTRU LPRime Core */
783 
784 #ifdef LPR
785 
786 /* (G,A),a = KeyGen(G); leaves G unchanged */
787 static void KeyGen(Fq *A,small *a,const Fq *G)
788 {
789   Fq aG[p];
790 
791   Short_random(a);
792   Rq_mult_small(aG,G,a);
793   Round(A,aG);
794 }
795 
796 /* B,T = Encrypt(r,(G,A),b) */
797 static void Encrypt(Fq *B,int8 *T,const int8 *r,const Fq *G,const Fq *A,const small *b)
798 {
799   Fq bG[p];
800   Fq bA[p];
801   int i;
802 
803   Rq_mult_small(bG,G,b);
804   Round(B,bG);
805   Rq_mult_small(bA,A,b);
806   for (i = 0;i < I;++i) T[i] = Top(Fq_freeze(bA[i]+r[i]*q12));
807 }
808 
809 /* r = Decrypt((B,T),a) */
810 static void Decrypt(int8 *r,const Fq *B,const int8 *T,const small *a)
811 {
812   Fq aB[p];
813   int i;
814 
815   Rq_mult_small(aB,B,a);
816   for (i = 0;i < I;++i)
817     r[i] = -int16_negative_mask(Fq_freeze(Right(T[i])-aB[i]+4*w+1));
818 }
819 
820 #endif
821 
822 /* ----- encoding I-bit inputs */
823 
824 #ifdef LPR
825 
826 #define Inputs_bytes (I/8)
827 typedef int8 Inputs[I]; /* passed by reference */
828 
829 static void Inputs_encode(unsigned char *s,const Inputs r)
830 {
831   int i;
832   for (i = 0;i < Inputs_bytes;++i) s[i] = 0;
833   for (i = 0;i < I;++i) s[i>>3] |= r[i]<<(i&7);
834 }
835 
836 #endif
837 
838 /* ----- Expand */
839 
840 #ifdef LPR
841 
842 static const unsigned char aes_nonce[16] = {0};
843 
844 static void Expand(uint32 *L,const unsigned char *k)
845 {
846   int i;
847   crypto_stream_aes256ctr((unsigned char *) L,4*p,aes_nonce,k);
848   for (i = 0;i < p;++i) {
849     uint32 L0 = ((unsigned char *) L)[4*i];
850     uint32 L1 = ((unsigned char *) L)[4*i+1];
851     uint32 L2 = ((unsigned char *) L)[4*i+2];
852     uint32 L3 = ((unsigned char *) L)[4*i+3];
853     L[i] = L0+(L1<<8)+(L2<<16)+(L3<<24);
854   }
855 }
856 
857 #endif
858 
859 /* ----- Seeds */
860 
861 #ifdef LPR
862 
863 #define Seeds_bytes 32
864 
/* Fill s with Seeds_bytes bytes obtained from randombytes(). */
static void Seeds_random(unsigned char *s)
{
  randombytes(s,Seeds_bytes);
}
869 
870 #endif
871 
872 /* ----- Generator, HashShort */
873 
874 #ifdef LPR
875 
876 /* G = Generator(k) */
877 static void Generator(Fq *G,const unsigned char *k)
878 {
879   uint32 L[p];
880   int i;
881 
882   Expand(L,k);
883   for (i = 0;i < p;++i) G[i] = uint32_mod_uint14(L[i],q)-q12;
884 }
885 
886 /* out = HashShort(r) */
887 static void HashShort(small *out,const Inputs r)
888 {
889   unsigned char s[Inputs_bytes];
890   unsigned char h[Hash_bytes];
891   uint32 L[p];
892 
893   Inputs_encode(s,r);
894   Hash_prefix(h,5,s,sizeof s);
895   Expand(L,h);
896   Short_fromlist(out,L);
897 }
898 
899 #endif
900 
901 /* ----- NTRU LPRime Expand */
902 
903 #ifdef LPR
904 
905 /* (S,A),a = XKeyGen() */
906 static void XKeyGen(unsigned char *S,Fq *A,small *a)
907 {
908   Fq G[p];
909 
910   Seeds_random(S);
911   Generator(G,S);
912   KeyGen(A,a,G);
913 }
914 
915 /* B,T = XEncrypt(r,(S,A)) */
916 static void XEncrypt(Fq *B,int8 *T,const int8 *r,const unsigned char *S,const Fq *A)
917 {
918   Fq G[p];
919   small b[p];
920 
921   Generator(G,S);
922   HashShort(b,r);
923   Encrypt(B,T,r,G,A,b);
924 }
925 
926 #define XDecrypt Decrypt
927 
928 #endif
929 
930 /* ----- encoding small polynomials (including short polynomials) */
931 
932 #define Small_bytes ((p+3)/4)
933 
934 /* these are the only functions that rely on p mod 4 = 1 */
935 
936 static void Small_encode(unsigned char *s,const small *f)
937 {
938   small x;
939   int i;
940 
941   for (i = 0;i < p/4;++i) {
942     x = *f++ + 1;
943     x += (*f++ + 1)<<2;
944     x += (*f++ + 1)<<4;
945     x += (*f++ + 1)<<6;
946     *s++ = x;
947   }
948   x = *f++ + 1;
949   *s++ = x;
950 }
951 
952 static void Small_decode(small *f,const unsigned char *s)
953 {
954   unsigned char x;
955   int i;
956 
957   for (i = 0;i < p/4;++i) {
958     x = *s++;
959     *f++ = ((small)(x&3))-1; x >>= 2;
960     *f++ = ((small)(x&3))-1; x >>= 2;
961     *f++ = ((small)(x&3))-1; x >>= 2;
962     *f++ = ((small)(x&3))-1;
963   }
964   x = *s++;
965   *f++ = ((small)(x&3))-1;
966 }
967 
968 /* ----- encoding general polynomials */
969 
970 #ifndef LPR
971 
972 static void Rq_encode(unsigned char *s,const Fq *r)
973 {
974   uint16 R[p],M[p];
975   int i;
976 
977   for (i = 0;i < p;++i) R[i] = r[i]+q12;
978   for (i = 0;i < p;++i) M[i] = q;
979   Encode(s,R,M,p);
980 }
981 
982 static void Rq_decode(Fq *r,const unsigned char *s)
983 {
984   uint16 R[p],M[p];
985   int i;
986 
987   for (i = 0;i < p;++i) M[i] = q;
988   Decode(R,s,M,p);
989   for (i = 0;i < p;++i) r[i] = ((Fq)R[i])-q12;
990 }
991 
992 #endif
993 
994 /* ----- encoding rounded polynomials */
995 
996 static void Rounded_encode(unsigned char *s,const Fq *r)
997 {
998   uint16 R[p],M[p];
999   int i;
1000 
1001   for (i = 0;i < p;++i) R[i] = ((r[i]+q12)*10923)>>15;
1002   for (i = 0;i < p;++i) M[i] = (q+2)/3;
1003   Encode(s,R,M,p);
1004 }
1005 
1006 static void Rounded_decode(Fq *r,const unsigned char *s)
1007 {
1008   uint16 R[p],M[p];
1009   int i;
1010 
1011   for (i = 0;i < p;++i) M[i] = (q+2)/3;
1012   Decode(R,s,M,p);
1013   for (i = 0;i < p;++i) r[i] = R[i]*3-q12;
1014 }
1015 
1016 /* ----- encoding top polynomials */
1017 
1018 #ifdef LPR
1019 
1020 #define Top_bytes (I/2)
1021 
1022 static void Top_encode(unsigned char *s,const int8 *T)
1023 {
1024   int i;
1025   for (i = 0;i < Top_bytes;++i)
1026     s[i] = T[2*i]+(T[2*i+1]<<4);
1027 }
1028 
1029 static void Top_decode(int8 *T,const unsigned char *s)
1030 {
1031   int i;
1032   for (i = 0;i < Top_bytes;++i) {
1033     T[2*i] = s[i]&15;
1034     T[2*i+1] = s[i]>>4;
1035   }
1036 }
1037 
1038 #endif
1039 
1040 /* ----- Streamlined NTRU Prime Core plus encoding */
1041 
1042 #ifndef LPR
1043 
/* In this variant an input is a small polynomial, generated and encoded as such. */
typedef small Inputs[p]; /* passed by reference */
#define Inputs_random Short_random
#define Inputs_encode Small_encode
#define Inputs_bytes Small_bytes

/* Sizes of the encoded objects handled by ZKeyGen/ZEncrypt/ZDecrypt. */
#define Ciphertexts_bytes Rounded_bytes
#define SecretKeys_bytes (2*Small_bytes) /* encoded f followed by encoded v */
#define PublicKeys_bytes Rq_bytes
1052 
1053 /* pk,sk = ZKeyGen() */
1054 static void ZKeyGen(unsigned char *pk,unsigned char *sk)
1055 {
1056   Fq h[p];
1057   small f[p],v[p];
1058 
1059   KeyGen(h,f,v);
1060   Rq_encode(pk,h);
1061   Small_encode(sk,f); sk += Small_bytes;
1062   Small_encode(sk,v);
1063 }
1064 
1065 /* C = ZEncrypt(r,pk) */
1066 static void ZEncrypt(unsigned char *C,const Inputs r,const unsigned char *pk)
1067 {
1068   Fq h[p];
1069   Fq c[p];
1070   Rq_decode(h,pk);
1071   Encrypt(c,r,h);
1072   Rounded_encode(C,c);
1073 }
1074 
1075 /* r = ZDecrypt(C,sk) */
1076 static void ZDecrypt(Inputs r,const unsigned char *C,const unsigned char *sk)
1077 {
1078   small f[p],v[p];
1079   Fq c[p];
1080 
1081   Small_decode(f,sk); sk += Small_bytes;
1082   Small_decode(v,sk);
1083   Rounded_decode(c,C);
1084   Decrypt(r,c,f,v);
1085 }
1086 
1087 #endif
1088 
1089 /* ----- NTRU LPRime Expand plus encoding */
1090 
1091 #ifdef LPR
1092 
1093 #define Ciphertexts_bytes (Rounded_bytes+Top_bytes)
1094 #define SecretKeys_bytes Small_bytes
1095 #define PublicKeys_bytes (Seeds_bytes+Rounded_bytes)
1096 
1097 static void Inputs_random(Inputs r)
1098 {
1099   unsigned char s[Inputs_bytes];
1100   int i;
1101 
1102   randombytes(s,sizeof s);
1103   for (i = 0;i < I;++i) r[i] = 1&(s[i>>3]>>(i&7));
1104 }
1105 
1106 /* pk,sk = ZKeyGen() */
1107 static void ZKeyGen(unsigned char *pk,unsigned char *sk)
1108 {
1109   Fq A[p];
1110   small a[p];
1111 
1112   XKeyGen(pk,A,a); pk += Seeds_bytes;
1113   Rounded_encode(pk,A);
1114   Small_encode(sk,a);
1115 }
1116 
1117 /* c = ZEncrypt(r,pk) */
1118 static void ZEncrypt(unsigned char *c,const Inputs r,const unsigned char *pk)
1119 {
1120   Fq A[p];
1121   Fq B[p];
1122   int8 T[I];
1123 
1124   Rounded_decode(A,pk+Seeds_bytes);
1125   XEncrypt(B,T,r,pk,A);
1126   Rounded_encode(c,B); c += Rounded_bytes;
1127   Top_encode(c,T);
1128 }
1129 
1130 /* r = ZDecrypt(C,sk) */
1131 static void ZDecrypt(Inputs r,const unsigned char *c,const unsigned char *sk)
1132 {
1133   small a[p];
1134   Fq B[p];
1135   int8 T[I];
1136 
1137   Small_decode(a,sk);
1138   Rounded_decode(B,c);
1139   Top_decode(T,c+Rounded_bytes);
1140   XDecrypt(r,B,T,a);
1141 }
1142 
1143 #endif
1144 
1145 /* ----- confirmation hash */
1146 
1147 #define Confirm_bytes 32
1148 
1149 /* h = HashConfirm(r,pk,cache); cache is Hash4(pk) */
1150 static void HashConfirm(unsigned char *h,const unsigned char *r,const unsigned char *pk,const unsigned char *cache)
1151 {
1152 #ifndef LPR
1153   unsigned char x[Hash_bytes*2];
1154   int i;
1155 
1156   Hash_prefix(x,3,r,Inputs_bytes);
1157   for (i = 0;i < Hash_bytes;++i) x[Hash_bytes+i] = cache[i];
1158 #else
1159   unsigned char x[Inputs_bytes+Hash_bytes];
1160   int i;
1161 
1162   for (i = 0;i < Inputs_bytes;++i) x[i] = r[i];
1163   for (i = 0;i < Hash_bytes;++i) x[Inputs_bytes+i] = cache[i];
1164 #endif
1165   Hash_prefix(h,2,x,sizeof x);
1166 }
1167 
1168 /* ----- session-key hash */
1169 
1170 /* k = HashSession(b,y,z) */
1171 static void HashSession(unsigned char *k,int b,const unsigned char *y,const unsigned char *z)
1172 {
1173 #ifndef LPR
1174   unsigned char x[Hash_bytes+Ciphertexts_bytes+Confirm_bytes];
1175   int i;
1176 
1177   Hash_prefix(x,3,y,Inputs_bytes);
1178   for (i = 0;i < Ciphertexts_bytes+Confirm_bytes;++i) x[Hash_bytes+i] = z[i];
1179 #else
1180   unsigned char x[Inputs_bytes+Ciphertexts_bytes+Confirm_bytes];
1181   int i;
1182 
1183   for (i = 0;i < Inputs_bytes;++i) x[i] = y[i];
1184   for (i = 0;i < Ciphertexts_bytes+Confirm_bytes;++i) x[Inputs_bytes+i] = z[i];
1185 #endif
1186   Hash_prefix(k,b,x,sizeof x);
1187 }
1188 
1189 /* ----- Streamlined NTRU Prime and NTRU LPRime */
1190 
1191 /* pk,sk = KEM_KeyGen() */
1192 static void KEM_KeyGen(unsigned char *pk,unsigned char *sk)
1193 {
1194   int i;
1195 
1196   ZKeyGen(pk,sk); sk += SecretKeys_bytes;
1197   for (i = 0;i < PublicKeys_bytes;++i) *sk++ = pk[i];
1198   randombytes(sk,Inputs_bytes); sk += Inputs_bytes;
1199   Hash_prefix(sk,4,pk,PublicKeys_bytes);
1200 }
1201 
1202 /* c,r_enc = Hide(r,pk,cache); cache is Hash4(pk) */
1203 static void Hide(unsigned char *c,unsigned char *r_enc,const Inputs r,const unsigned char *pk,const unsigned char *cache)
1204 {
1205   Inputs_encode(r_enc,r);
1206   ZEncrypt(c,r,pk); c += Ciphertexts_bytes;
1207   HashConfirm(c,r_enc,pk,cache);
1208 }
1209 
1210 /* c,k = Encap(pk) */
1211 static void Encap(unsigned char *c,unsigned char *k,const unsigned char *pk)
1212 {
1213   Inputs r;
1214   unsigned char r_enc[Inputs_bytes];
1215   unsigned char cache[Hash_bytes];
1216 
1217   Hash_prefix(cache,4,pk,PublicKeys_bytes);
1218   Inputs_random(r);
1219   Hide(c,r_enc,r,pk,cache);
1220   HashSession(k,1,r_enc,c);
1221 }
1222 
1223 /* 0 if matching ciphertext+confirm, else -1 */
1224 static int Ciphertexts_diff_mask(const unsigned char *c,const unsigned char *c2)
1225 {
1226   uint16 differentbits = 0;
1227   int len = Ciphertexts_bytes+Confirm_bytes;
1228 
1229   while (len-- > 0) differentbits |= (*c++)^(*c2++);
1230   return (1&((differentbits-1)>>8))-1;
1231 }
1232 
1233 /* k = Decap(c,sk) */
1234 static void Decap(unsigned char *k,const unsigned char *c,const unsigned char *sk)
1235 {
1236   const unsigned char *pk = sk + SecretKeys_bytes;
1237   const unsigned char *rho = pk + PublicKeys_bytes;
1238   const unsigned char *cache = rho + Inputs_bytes;
1239   Inputs r;
1240   unsigned char r_enc[Inputs_bytes];
1241   unsigned char cnew[Ciphertexts_bytes+Confirm_bytes];
1242   int mask;
1243   int i;
1244 
1245   ZDecrypt(r,c,sk);
1246   Hide(cnew,r_enc,r,pk,cache);
1247   mask = Ciphertexts_diff_mask(c,cnew);
1248   for (i = 0;i < Inputs_bytes;++i) r_enc[i] ^= mask&(r_enc[i]^rho[i]);
1249   HashSession(k,1+mask,r_enc,c);
1250 }
1251 
1252 /* ----- crypto_kem API */
1253 
1254 
/*
 * Generate a key pair via KEM_KeyGen.
 * pk receives PublicKeys_bytes bytes; sk receives the secret key followed
 * by a copy of pk, Inputs_bytes random bytes and a 32-byte hash of pk.
 * Always returns 0.
 */
int crypto_kem_sntrup761_keypair(unsigned char *pk,unsigned char *sk)
{
  KEM_KeyGen(pk,sk);
  return 0;
}
1260 
/*
 * Encapsulate to the public key pk via Encap.
 * c receives Ciphertexts_bytes+Confirm_bytes bytes and k the 32-byte
 * session key.  Always returns 0.
 */
int crypto_kem_sntrup761_enc(unsigned char *c,unsigned char *k,const unsigned char *pk)
{
  Encap(c,k,pk);
  return 0;
}
1266 
/*
 * Decapsulate the ciphertext c with the secret key sk via Decap and write
 * the 32-byte session key to k.  Always returns 0; a ciphertext that
 * fails the re-encryption check still yields a key, derived from the
 * random bytes stored in sk.
 */
int crypto_kem_sntrup761_dec(unsigned char *k,const unsigned char *c,const unsigned char *sk)
{
  Decap(k,c,sk);
  return 0;
}
1272 
1273