/* xref: /openbsd-src/usr.bin/ssh/chacha.c (revision 254cc503a1fee50f52e3da5adb3493b1cf93ddb9) */
/* $OpenBSD: chacha.c,v 1.2 2023/07/17 05:26:38 djm Exp $ */
/*
chacha-merged.c version 20080118
D. J. Bernstein
Public domain.
*/

8 #include "chacha.h"
/*
 * Integer aliases and byte/word helpers used by the ChaCha functions in
 * this file.
 * NOTE(review): u32 assumes unsigned int is at least 32 bits wide; the
 * U32V() masks truncate results to 32 bits.
 */
typedef unsigned char u8;
typedef unsigned int u32;

/* struct chacha_ctx itself (with its input[] word array) comes from the
 * included header. */
typedef struct chacha_ctx chacha_ctx;

/* Append an unsigned suffix to an integer literal. */
#define U8C(v) (v##U)
#define U32C(v) (v##U)

/* Truncate a value to 8 / 32 bits. */
#define U8V(v) ((u8)(v) & U8C(0xFF))
#define U32V(v) ((u32)(v) & U32C(0xFFFFFFFF))

/* Rotate a 32-bit word left by n bits.  n must be in 1..31: n == 0 would
 * shift right by 32, which is undefined for a 32-bit operand. */
#define ROTL32(v, n) \
  (U32V((v) << (n)) | ((v) >> (32 - (n))))

/* Load 4 bytes at p as a little-endian 32-bit word. */
#define U8TO32_LITTLE(p) \
  (((u32)((p)[0])      ) | \
   ((u32)((p)[1]) <<  8) | \
   ((u32)((p)[2]) << 16) | \
   ((u32)((p)[3]) << 24))

/* Store the 32-bit word v at p as 4 little-endian bytes. */
#define U32TO8_LITTLE(p, v) \
  do { \
    (p)[0] = U8V((v)      ); \
    (p)[1] = U8V((v) >>  8); \
    (p)[2] = U8V((v) >> 16); \
    (p)[3] = U8V((v) >> 24); \
  } while (0)

#define ROTATE(v,c) (ROTL32(v,c))
#define XOR(v,w) ((v) ^ (w))
#define PLUS(v,w) (U32V((v) + (w)))
#define PLUSONE(v) (PLUS((v),1))

/*
 * ChaCha quarter round: updates the four word lvalues a, b, c, d in place.
 * The body is wrapped in braces, not do/while(0), so the expansion is a
 * single compound statement whether or not the invocation is followed by
 * a semicolon.  Without the braces, "if (cond) QUARTERROUND(...)" would
 * guard only the first assignment.
 */
#define QUARTERROUND(a,b,c,d) \
  { \
    a = PLUS(a,b); d = ROTATE(XOR(d,a),16); \
    c = PLUS(c,d); b = ROTATE(XOR(b,c),12); \
    a = PLUS(a,b); d = ROTATE(XOR(d,a), 8); \
    c = PLUS(c,d); b = ROTATE(XOR(b,c), 7); \
  }

/*
 * Block-function constants for 256-bit (sigma) and 128-bit (tau) keys.
 * The arrays hold exactly the 16 characters with no terminating NUL.
 * They are plain char, but every character is 7-bit ASCII, so
 * U8TO32_LITTLE() sees no sign extension even where char is signed.
 */
static const char sigma[16] = "expand 32-byte k";
static const char tau[16] = "expand 16-byte k";

52 void
chacha_keysetup(chacha_ctx * x,const u8 * k,u32 kbits)53 chacha_keysetup(chacha_ctx *x,const u8 *k,u32 kbits)
54 {
55   const char *constants;
56 
57   x->input[4] = U8TO32_LITTLE(k + 0);
58   x->input[5] = U8TO32_LITTLE(k + 4);
59   x->input[6] = U8TO32_LITTLE(k + 8);
60   x->input[7] = U8TO32_LITTLE(k + 12);
61   if (kbits == 256) { /* recommended */
62     k += 16;
63     constants = sigma;
64   } else { /* kbits == 128 */
65     constants = tau;
66   }
67   x->input[8] = U8TO32_LITTLE(k + 0);
68   x->input[9] = U8TO32_LITTLE(k + 4);
69   x->input[10] = U8TO32_LITTLE(k + 8);
70   x->input[11] = U8TO32_LITTLE(k + 12);
71   x->input[0] = U8TO32_LITTLE(constants + 0);
72   x->input[1] = U8TO32_LITTLE(constants + 4);
73   x->input[2] = U8TO32_LITTLE(constants + 8);
74   x->input[3] = U8TO32_LITTLE(constants + 12);
75 }
76 
77 void
chacha_ivsetup(chacha_ctx * x,const u8 * iv,const u8 * counter)78 chacha_ivsetup(chacha_ctx *x, const u8 *iv, const u8 *counter)
79 {
80   x->input[12] = counter == NULL ? 0 : U8TO32_LITTLE(counter + 0);
81   x->input[13] = counter == NULL ? 0 : U8TO32_LITTLE(counter + 4);
82   x->input[14] = U8TO32_LITTLE(iv + 0);
83   x->input[15] = U8TO32_LITTLE(iv + 4);
84 }
85 
86 void
chacha_encrypt_bytes(chacha_ctx * x,const u8 * m,u8 * c,u32 bytes)87 chacha_encrypt_bytes(chacha_ctx *x,const u8 *m,u8 *c,u32 bytes)
88 {
89   u32 x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15;
90   u32 j0, j1, j2, j3, j4, j5, j6, j7, j8, j9, j10, j11, j12, j13, j14, j15;
91   u8 *ctarget = NULL;
92   u8 tmp[64];
93   u_int i;
94 
95   if (!bytes) return;
96 
97   j0 = x->input[0];
98   j1 = x->input[1];
99   j2 = x->input[2];
100   j3 = x->input[3];
101   j4 = x->input[4];
102   j5 = x->input[5];
103   j6 = x->input[6];
104   j7 = x->input[7];
105   j8 = x->input[8];
106   j9 = x->input[9];
107   j10 = x->input[10];
108   j11 = x->input[11];
109   j12 = x->input[12];
110   j13 = x->input[13];
111   j14 = x->input[14];
112   j15 = x->input[15];
113 
114   for (;;) {
115     if (bytes < 64) {
116       for (i = 0;i < bytes;++i) tmp[i] = m[i];
117       m = tmp;
118       ctarget = c;
119       c = tmp;
120     }
121     x0 = j0;
122     x1 = j1;
123     x2 = j2;
124     x3 = j3;
125     x4 = j4;
126     x5 = j5;
127     x6 = j6;
128     x7 = j7;
129     x8 = j8;
130     x9 = j9;
131     x10 = j10;
132     x11 = j11;
133     x12 = j12;
134     x13 = j13;
135     x14 = j14;
136     x15 = j15;
137     for (i = 20;i > 0;i -= 2) {
138       QUARTERROUND( x0, x4, x8,x12)
139       QUARTERROUND( x1, x5, x9,x13)
140       QUARTERROUND( x2, x6,x10,x14)
141       QUARTERROUND( x3, x7,x11,x15)
142       QUARTERROUND( x0, x5,x10,x15)
143       QUARTERROUND( x1, x6,x11,x12)
144       QUARTERROUND( x2, x7, x8,x13)
145       QUARTERROUND( x3, x4, x9,x14)
146     }
147     x0 = PLUS(x0,j0);
148     x1 = PLUS(x1,j1);
149     x2 = PLUS(x2,j2);
150     x3 = PLUS(x3,j3);
151     x4 = PLUS(x4,j4);
152     x5 = PLUS(x5,j5);
153     x6 = PLUS(x6,j6);
154     x7 = PLUS(x7,j7);
155     x8 = PLUS(x8,j8);
156     x9 = PLUS(x9,j9);
157     x10 = PLUS(x10,j10);
158     x11 = PLUS(x11,j11);
159     x12 = PLUS(x12,j12);
160     x13 = PLUS(x13,j13);
161     x14 = PLUS(x14,j14);
162     x15 = PLUS(x15,j15);
163 
164     x0 = XOR(x0,U8TO32_LITTLE(m + 0));
165     x1 = XOR(x1,U8TO32_LITTLE(m + 4));
166     x2 = XOR(x2,U8TO32_LITTLE(m + 8));
167     x3 = XOR(x3,U8TO32_LITTLE(m + 12));
168     x4 = XOR(x4,U8TO32_LITTLE(m + 16));
169     x5 = XOR(x5,U8TO32_LITTLE(m + 20));
170     x6 = XOR(x6,U8TO32_LITTLE(m + 24));
171     x7 = XOR(x7,U8TO32_LITTLE(m + 28));
172     x8 = XOR(x8,U8TO32_LITTLE(m + 32));
173     x9 = XOR(x9,U8TO32_LITTLE(m + 36));
174     x10 = XOR(x10,U8TO32_LITTLE(m + 40));
175     x11 = XOR(x11,U8TO32_LITTLE(m + 44));
176     x12 = XOR(x12,U8TO32_LITTLE(m + 48));
177     x13 = XOR(x13,U8TO32_LITTLE(m + 52));
178     x14 = XOR(x14,U8TO32_LITTLE(m + 56));
179     x15 = XOR(x15,U8TO32_LITTLE(m + 60));
180 
181     j12 = PLUSONE(j12);
182     if (!j12) {
183       j13 = PLUSONE(j13);
184       /* stopping at 2^70 bytes per nonce is user's responsibility */
185     }
186 
187     U32TO8_LITTLE(c + 0,x0);
188     U32TO8_LITTLE(c + 4,x1);
189     U32TO8_LITTLE(c + 8,x2);
190     U32TO8_LITTLE(c + 12,x3);
191     U32TO8_LITTLE(c + 16,x4);
192     U32TO8_LITTLE(c + 20,x5);
193     U32TO8_LITTLE(c + 24,x6);
194     U32TO8_LITTLE(c + 28,x7);
195     U32TO8_LITTLE(c + 32,x8);
196     U32TO8_LITTLE(c + 36,x9);
197     U32TO8_LITTLE(c + 40,x10);
198     U32TO8_LITTLE(c + 44,x11);
199     U32TO8_LITTLE(c + 48,x12);
200     U32TO8_LITTLE(c + 52,x13);
201     U32TO8_LITTLE(c + 56,x14);
202     U32TO8_LITTLE(c + 60,x15);
203 
204     if (bytes <= 64) {
205       if (bytes < 64) {
206         for (i = 0;i < bytes;++i) ctarget[i] = c[i];
207       }
208       x->input[12] = j12;
209       x->input[13] = j13;
210       return;
211     }
212     bytes -= 64;
213     c += 64;
214     m += 64;
215   }
216 }
217