1 /*
2 * RSA authentication.
3 *
4 * Old ssh client protocol:
5 * read public key
6 * if you don't like it, read another, repeat
7 * write challenge
8 * read response
9 *
 * all numbers here are hexadecimal bigints parsable with strtomp.
 * (Sign and Verify below exchange raw bytes, not hex.)
11 *
12 * Sign (PKCS #1 using hash=sha1 or hash=md5)
13 * write hash(msg)
14 * read signature(hash(msg))
15 *
16 * Verify:
17 * write hash(msg)
18 * write signature(hash(msg))
 * read "ok" or "signature does not verify"
20 */
21
22 #include "dat.h"
23
/*
 * Conversation phases.  C* phases serve role=client, S* phases
 * role=sign and V* phases role=verify (rsainit picks the first one).
 * Maxphase is the number of phases and sizes the name table below.
 */
enum {
	CHavePub = 0,	/* client: read offers a key's modulus, write takes a challenge */
	CHaveResp,	/* client: read returns the decrypted challenge */
	VNeedHash,	/* verify: write takes the hash */
	VNeedSig,	/* verify: write takes the signature */
	VHaveResp,	/* verify: read returns the verdict */
	SNeedHash,	/* sign: write takes the hash */
	SHaveResp,	/* sign: read returns the signature */
	Maxphase,
};

/* printable phase names, indexed by phase; handed to fss->phasename */
static char *phasenames[Maxphase] = {
	[CHavePub]	"CHavePub",
	[CHaveResp]	"CHaveResp",
	[VNeedHash]	"VNeedHash",
	[VNeedSig]	"VNeedSig",
	[VHaveResp]	"VHaveResp",
	[SNeedHash]	"SNeedHash",
	[SHaveResp]	"SHaveResp",
};
44
/*
 * Per-conversation state: allocated by rsainit (stored in fss->ps)
 * and released by rsaclose.  rsaread and rsaclose test pointer fields
 * against nil before use, so this relies on emalloc returning zeroed
 * memory -- NOTE(review): confirm emalloc zeroes.
 */
struct State
{
	RSApriv *priv;	/* NOTE(review): never read or written in this file */
	mpint *resp;	/* client: decrypted challenge, printed by rsaread */
	int off;	/* client: how many keys have already been offered */
	Key *key;	/* key currently in use */
	mpint *digest;	/* sign/verify: PKCS#1-encoded hash from mkdigest */
	int sigresp;	/* verify: mpcmp result; 0 means the signature matched */
};

/* PKCS#1-encode a hash for key's modulus; defined at the end of the file */
static mpint* mkdigest(RSApub *key, char *hashalg, uchar *hash, uint dlen);
56
57 static RSApriv*
readrsapriv(Key * k)58 readrsapriv(Key *k)
59 {
60 char *a;
61 RSApriv *priv;
62
63 priv = rsaprivalloc();
64
65 if((a=_strfindattr(k->attr, "ek"))==nil || (priv->pub.ek=strtomp(a, nil, 16, nil))==nil)
66 goto Error;
67 if((a=_strfindattr(k->attr, "n"))==nil || (priv->pub.n=strtomp(a, nil, 16, nil))==nil)
68 goto Error;
69 if(k->privattr == nil) /* only public half */
70 return priv;
71 if((a=_strfindattr(k->privattr, "!p"))==nil || (priv->p=strtomp(a, nil, 16, nil))==nil)
72 goto Error;
73 if((a=_strfindattr(k->privattr, "!q"))==nil || (priv->q=strtomp(a, nil, 16, nil))==nil)
74 goto Error;
75 if((a=_strfindattr(k->privattr, "!kp"))==nil || (priv->kp=strtomp(a, nil, 16, nil))==nil)
76 goto Error;
77 if((a=_strfindattr(k->privattr, "!kq"))==nil || (priv->kq=strtomp(a, nil, 16, nil))==nil)
78 goto Error;
79 if((a=_strfindattr(k->privattr, "!c2"))==nil || (priv->c2=strtomp(a, nil, 16, nil))==nil)
80 goto Error;
81 if((a=_strfindattr(k->privattr, "!dk"))==nil || (priv->dk=strtomp(a, nil, 16, nil))==nil)
82 goto Error;
83 return priv;
84
85 Error:
86 rsaprivfree(priv);
87 return nil;
88 }
89
90 static int
rsainit(Proto *,Fsstate * fss)91 rsainit(Proto*, Fsstate *fss)
92 {
93 Keyinfo ki;
94 State *s;
95 char *role;
96
97 if((role = _strfindattr(fss->attr, "role")) == nil)
98 return failure(fss, "rsa role not specified");
99 if(strcmp(role, "client") == 0)
100 fss->phase = CHavePub;
101 else if(strcmp(role, "sign") == 0)
102 fss->phase = SNeedHash;
103 else if(strcmp(role, "verify") == 0)
104 fss->phase = VNeedHash;
105 else
106 return failure(fss, "rsa role %s unimplemented", role);
107
108 s = emalloc(sizeof *s);
109 fss->phasename = phasenames;
110 fss->maxphase = Maxphase;
111 fss->ps = s;
112
113 switch(fss->phase){
114 case SNeedHash:
115 case VNeedHash:
116 mkkeyinfo(&ki, fss, nil);
117 if(findkey(&s->key, &ki, nil) != RpcOk)
118 return failure(fss, nil);
119 /* signing needs private key */
120 if(fss->phase == SNeedHash && s->key->privattr == nil)
121 return failure(fss,
122 "missing private half of key -- cannot sign");
123 }
124 return RpcOk;
125 }
126
127 static int
rsaread(Fsstate * fss,void * va,uint * n)128 rsaread(Fsstate *fss, void *va, uint *n)
129 {
130 RSApriv *priv;
131 State *s;
132 mpint *m;
133 Keyinfo ki;
134 int len, r;
135
136 s = fss->ps;
137 switch(fss->phase){
138 default:
139 return phaseerror(fss, "read");
140 case CHavePub:
141 if(s->key){
142 closekey(s->key);
143 s->key = nil;
144 }
145 mkkeyinfo(&ki, fss, nil);
146 ki.skip = s->off;
147 ki.noconf = 1;
148 if(findkey(&s->key, &ki, nil) != RpcOk)
149 return failure(fss, nil);
150 s->off++;
151 priv = s->key->priv;
152 *n = snprint(va, *n, "%B", priv->pub.n);
153 return RpcOk;
154 case CHaveResp:
155 *n = snprint(va, *n, "%B", s->resp);
156 fss->phase = Established;
157 return RpcOk;
158 case SHaveResp:
159 priv = s->key->priv;
160 len = (mpsignif(priv->pub.n)+7)/8;
161 if(len > *n)
162 return failure(fss, "signature buffer too short");
163 m = rsadecrypt(priv, s->digest, nil);
164 r = mptobe(m, (uchar*)va, len, nil);
165 if(r < len){
166 memmove((uchar*)va+len-r, va, r);
167 memset(va, 0, len-r);
168 }
169 *n = len;
170 mpfree(m);
171 fss->phase = Established;
172 return RpcOk;
173 case VHaveResp:
174 *n = snprint(va, *n, "%s", s->sigresp == 0? "ok":
175 "signature does not verify");
176 fss->phase = Established;
177 return RpcOk;
178 }
179 }
180
181 static int
rsawrite(Fsstate * fss,void * va,uint n)182 rsawrite(Fsstate *fss, void *va, uint n)
183 {
184 RSApriv *priv;
185 mpint *m, *mm;
186 State *s;
187 char *hash;
188 int dlen;
189
190 s = fss->ps;
191 switch(fss->phase){
192 default:
193 return phaseerror(fss, "write");
194 case CHavePub:
195 if(s->key == nil)
196 return failure(fss, "no current key");
197 switch(canusekey(fss, s->key)){
198 case -1:
199 return RpcConfirm;
200 case 0:
201 return failure(fss, "confirmation denied");
202 case 1:
203 break;
204 }
205 m = strtomp(va, nil, 16, nil);
206 if(m == nil)
207 return failure(fss, "invalid challenge value");
208 m = rsadecrypt(s->key->priv, m, m);
209 s->resp = m;
210 fss->phase = CHaveResp;
211 return RpcOk;
212 case SNeedHash:
213 case VNeedHash:
214 /* get hash type from key */
215 hash = _strfindattr(s->key->attr, "hash");
216 if(hash == nil)
217 hash = "sha1";
218 if(strcmp(hash, "sha1") == 0)
219 dlen = SHA1dlen;
220 else if(strcmp(hash, "md5") == 0)
221 dlen = MD5dlen;
222 else
223 return failure(fss, "unknown hash function %s", hash);
224 if(n != dlen)
225 return failure(fss, "hash length %d should be %d",
226 n, dlen);
227 priv = s->key->priv;
228 s->digest = mkdigest(&priv->pub, hash, (uchar *)va, n);
229 if(s->digest == nil)
230 return failure(fss, nil);
231 if(fss->phase == VNeedHash)
232 fss->phase = VNeedSig;
233 else
234 fss->phase = SHaveResp;
235 return RpcOk;
236 case VNeedSig:
237 priv = s->key->priv;
238 m = betomp((uchar*)va, n, nil);
239 mm = rsaencrypt(&priv->pub, m, nil);
240 s->sigresp = mpcmp(s->digest, mm);
241 mpfree(m);
242 mpfree(mm);
243 fss->phase = VHaveResp;
244 return RpcOk;
245 }
246 }
247
248 static void
rsaclose(Fsstate * fss)249 rsaclose(Fsstate *fss)
250 {
251 State *s;
252
253 s = fss->ps;
254 if(s->key)
255 closekey(s->key);
256 if(s->resp)
257 mpfree(s->resp);
258 if(s->digest)
259 mpfree(s->digest);
260 free(s);
261 }
262
263 static int
rsaaddkey(Key * k,int before)264 rsaaddkey(Key *k, int before)
265 {
266 fmtinstall('B', mpfmt);
267
268 if((k->priv = readrsapriv(k)) == nil){
269 werrstr("malformed key data");
270 return -1;
271 }
272 return replacekey(k, before);
273 }
274
275 static void
rsaclosekey(Key * k)276 rsaclosekey(Key *k)
277 {
278 rsaprivfree(k->priv);
279 }
280
281 Proto rsa = {
282 .name= "rsa",
283 .init= rsainit,
284 .write= rsawrite,
285 .read= rsaread,
286 .close= rsaclose,
287 .addkey= rsaaddkey,
288 .closekey= rsaclosekey,
289 };
290
291 /*
292 * Simple ASN.1 encodings.
293 * Lengths < 128 are encoded as 1-bytes constants,
294 * making our life easy.
295 */
296
297 /*
298 * Hash OIDs
299 *
300 * SHA1 = 1.3.14.3.2.26
301 * MDx = 1.2.840.113549.2.x
302 */
303 #define O0(a,b) ((a)*40+(b))
304 #define O2(x) \
305 (((x)>> 7)&0x7F)|0x80, \
306 ((x)&0x7F)
307 #define O3(x) \
308 (((x)>>14)&0x7F)|0x80, \
309 (((x)>> 7)&0x7F)|0x80, \
310 ((x)&0x7F)
311 uchar oidsha1[] = { O0(1, 3), 14, 3, 2, 26 };
312 uchar oidmd2[] = { O0(1, 2), O2(840), O3(113549), 2, 2 };
313 uchar oidmd5[] = { O0(1, 2), O2(840), O3(113549), 2, 5 };
314
315 /*
316 * DigestInfo ::= SEQUENCE {
317 * digestAlgorithm AlgorithmIdentifier,
318 * digest OCTET STRING
319 * }
320 *
321 * except that OpenSSL seems to sign
322 *
323 * DigestInfo ::= SEQUENCE {
324 * SEQUENCE{ digestAlgorithm AlgorithmIdentifier, NULL }
325 * digest OCTET STRING
326 * }
327 *
328 * instead. Sigh.
329 */
330 static int
mkasn1(uchar * asn1,char * alg,uchar * d,uint dlen)331 mkasn1(uchar *asn1, char *alg, uchar *d, uint dlen)
332 {
333 uchar *obj, *p;
334 uint olen;
335
336 if(strcmp(alg, "sha1") == 0){
337 obj = oidsha1;
338 olen = sizeof(oidsha1);
339 }else if(strcmp(alg, "md5") == 0){
340 obj = oidmd5;
341 olen = sizeof(oidmd5);
342 }else{
343 sysfatal("bad alg in mkasn1");
344 return -1;
345 }
346
347 p = asn1;
348 *p++ = 0x30; /* sequence */
349 p++;
350
351 *p++ = 0x30; /* another sequence */
352 p++;
353
354 *p++ = 0x06; /* object id */
355 *p++ = olen;
356 memmove(p, obj, olen);
357 p += olen;
358
359 *p++ = 0x05; /* null */
360 *p++ = 0;
361
362 asn1[3] = p - (asn1+4); /* end of inner sequence */
363
364 *p++ = 0x04; /* octet string */
365 *p++ = dlen;
366 memmove(p, d, dlen);
367 p += dlen;
368
369 asn1[1] = p - (asn1+2); /* end of outer sequence */
370 return p - asn1;
371 }
372
373 static mpint*
mkdigest(RSApub * key,char * hashalg,uchar * hash,uint dlen)374 mkdigest(RSApub *key, char *hashalg, uchar *hash, uint dlen)
375 {
376 mpint *m;
377 uchar asn1[512], *buf;
378 int len, n, pad;
379
380 /*
381 * Create ASN.1
382 */
383 n = mkasn1(asn1, hashalg, hash, dlen);
384
385 /*
386 * PKCS#1 padding
387 */
388 len = (mpsignif(key->n)+7)/8 - 1;
389 if(len < n+2){
390 werrstr("rsa key too short");
391 return nil;
392 }
393 pad = len - (n+2);
394 buf = emalloc(len);
395 buf[0] = 0x01;
396 memset(buf+1, 0xFF, pad);
397 buf[1+pad] = 0x00;
398 memmove(buf+1+pad+1, asn1, n);
399 m = betomp(buf, len, nil);
400 free(buf);
401 return m;
402 }
403