xref: /plan9/sys/src/cmd/auth/factotum/rsa.c (revision 7f0b57c67146be0d1988c0e09c92be045222e800)
/*
 * RSA authentication.
 *
 * Old ssh client protocol:
 *	read public key
 *		if you don't like it, read another, repeat
 *	write challenge
 *	read response
 *
 * In the client protocol all numbers are hexadecimal bigints
 * parsable with strtomp.
 *
 * Sign (PKCS #1 using hash=sha1 or hash=md5):
 *	write hash(msg)
 *	read signature(hash(msg))
 *
 * Verify:
 *	write hash(msg)
 *	write signature(hash(msg))
 *	read "ok" or "signature does not verify"
 *
 * For sign and verify, hashes and signatures are raw big-endian
 * bytes, not hex text.
 */
21 
22 #include "dat.h"
23 
/*
 * Protocol phases.  C* phases belong to the ssh client role,
 * V* to the verify role and S* to the sign role.
 * Maxphase counts the phases; it is not a phase itself.
 */
enum {
	CHavePub = 0,	/* client: read offers a key modulus, write sends a challenge */
	CHaveResp,	/* client: read returns the decrypted challenge */
	VNeedHash,	/* verify: waiting for write of hash(msg) */
	VNeedSig,	/* verify: waiting for write of the signature */
	VHaveResp,	/* verify: read returns the verdict */
	SNeedHash,	/* sign: waiting for write of hash(msg) */
	SHaveResp,	/* sign: read returns the signature */
	Maxphase,	/* number of phases */
};
34 
35 static char *phasenames[] = {
36 [CHavePub]	"CHavePub",
37 [CHaveResp]	"CHaveResp",
38 [VNeedHash]	"VNeedHash",
39 [VNeedSig]	"VNeedSig",
40 [VHaveResp]	"VHaveResp",
41 [SNeedHash]	"SNeedHash",
42 [SHaveResp]	"SHaveResp",
43 };
44 
45 struct State
46 {
47 	RSApriv *priv;
48 	mpint *resp;
49 	int off;
50 	Key *key;
51 	mpint *digest;
52 	int sigresp;
53 };
54 
/* Forward declaration: rsawrite uses mkdigest, which is defined at the end of this file. */
static mpint* mkdigest(RSApub *key, char *hashalg, uchar *hash, uint dlen);
56 
57 static RSApriv*
readrsapriv(Key * k)58 readrsapriv(Key *k)
59 {
60 	char *a;
61 	RSApriv *priv;
62 
63 	priv = rsaprivalloc();
64 
65 	if((a=_strfindattr(k->attr, "ek"))==nil || (priv->pub.ek=strtomp(a, nil, 16, nil))==nil)
66 		goto Error;
67 	if((a=_strfindattr(k->attr, "n"))==nil || (priv->pub.n=strtomp(a, nil, 16, nil))==nil)
68 		goto Error;
69 	if(k->privattr == nil)		/* only public half */
70 		return priv;
71 	if((a=_strfindattr(k->privattr, "!p"))==nil || (priv->p=strtomp(a, nil, 16, nil))==nil)
72 		goto Error;
73 	if((a=_strfindattr(k->privattr, "!q"))==nil || (priv->q=strtomp(a, nil, 16, nil))==nil)
74 		goto Error;
75 	if((a=_strfindattr(k->privattr, "!kp"))==nil || (priv->kp=strtomp(a, nil, 16, nil))==nil)
76 		goto Error;
77 	if((a=_strfindattr(k->privattr, "!kq"))==nil || (priv->kq=strtomp(a, nil, 16, nil))==nil)
78 		goto Error;
79 	if((a=_strfindattr(k->privattr, "!c2"))==nil || (priv->c2=strtomp(a, nil, 16, nil))==nil)
80 		goto Error;
81 	if((a=_strfindattr(k->privattr, "!dk"))==nil || (priv->dk=strtomp(a, nil, 16, nil))==nil)
82 		goto Error;
83 	return priv;
84 
85 Error:
86 	rsaprivfree(priv);
87 	return nil;
88 }
89 
90 static int
rsainit(Proto *,Fsstate * fss)91 rsainit(Proto*, Fsstate *fss)
92 {
93 	Keyinfo ki;
94 	State *s;
95 	char *role;
96 
97 	if((role = _strfindattr(fss->attr, "role")) == nil)
98 		return failure(fss, "rsa role not specified");
99 	if(strcmp(role, "client") == 0)
100 		fss->phase = CHavePub;
101 	else if(strcmp(role, "sign") == 0)
102 		fss->phase = SNeedHash;
103 	else if(strcmp(role, "verify") == 0)
104 		fss->phase = VNeedHash;
105 	else
106 		return failure(fss, "rsa role %s unimplemented", role);
107 
108 	s = emalloc(sizeof *s);
109 	fss->phasename = phasenames;
110 	fss->maxphase = Maxphase;
111 	fss->ps = s;
112 
113 	switch(fss->phase){
114 	case SNeedHash:
115 	case VNeedHash:
116 		mkkeyinfo(&ki, fss, nil);
117 		if(findkey(&s->key, &ki, nil) != RpcOk)
118 			return failure(fss, nil);
119 		/* signing needs private key */
120 		if(fss->phase == SNeedHash && s->key->privattr == nil)
121 			return failure(fss,
122 				"missing private half of key -- cannot sign");
123 	}
124 	return RpcOk;
125 }
126 
127 static int
rsaread(Fsstate * fss,void * va,uint * n)128 rsaread(Fsstate *fss, void *va, uint *n)
129 {
130 	RSApriv *priv;
131 	State *s;
132 	mpint *m;
133 	Keyinfo ki;
134 	int len, r;
135 
136 	s = fss->ps;
137 	switch(fss->phase){
138 	default:
139 		return phaseerror(fss, "read");
140 	case CHavePub:
141 		if(s->key){
142 			closekey(s->key);
143 			s->key = nil;
144 		}
145 		mkkeyinfo(&ki, fss, nil);
146 		ki.skip = s->off;
147 		ki.noconf = 1;
148 		if(findkey(&s->key, &ki, nil) != RpcOk)
149 			return failure(fss, nil);
150 		s->off++;
151 		priv = s->key->priv;
152 		*n = snprint(va, *n, "%B", priv->pub.n);
153 		return RpcOk;
154 	case CHaveResp:
155 		*n = snprint(va, *n, "%B", s->resp);
156 		fss->phase = Established;
157 		return RpcOk;
158 	case SHaveResp:
159 		priv = s->key->priv;
160 		len = (mpsignif(priv->pub.n)+7)/8;
161 		if(len > *n)
162 			return failure(fss, "signature buffer too short");
163 		m = rsadecrypt(priv, s->digest, nil);
164 		r = mptobe(m, (uchar*)va, len, nil);
165 		if(r < len){
166 			memmove((uchar*)va+len-r, va, r);
167 			memset(va, 0, len-r);
168 		}
169 		*n = len;
170 		mpfree(m);
171 		fss->phase = Established;
172 		return RpcOk;
173 	case VHaveResp:
174 		*n = snprint(va, *n, "%s", s->sigresp == 0? "ok":
175 			"signature does not verify");
176 		fss->phase = Established;
177 		return RpcOk;
178 	}
179 }
180 
181 static int
rsawrite(Fsstate * fss,void * va,uint n)182 rsawrite(Fsstate *fss, void *va, uint n)
183 {
184 	RSApriv *priv;
185 	mpint *m, *mm;
186 	State *s;
187 	char *hash;
188 	int dlen;
189 
190 	s = fss->ps;
191 	switch(fss->phase){
192 	default:
193 		return phaseerror(fss, "write");
194 	case CHavePub:
195 		if(s->key == nil)
196 			return failure(fss, "no current key");
197 		switch(canusekey(fss, s->key)){
198 		case -1:
199 			return RpcConfirm;
200 		case 0:
201 			return failure(fss, "confirmation denied");
202 		case 1:
203 			break;
204 		}
205 		m = strtomp(va, nil, 16, nil);
206 		if(m == nil)
207 			return failure(fss, "invalid challenge value");
208 		m = rsadecrypt(s->key->priv, m, m);
209 		s->resp = m;
210 		fss->phase = CHaveResp;
211 		return RpcOk;
212 	case SNeedHash:
213 	case VNeedHash:
214 		/* get hash type from key */
215 		hash = _strfindattr(s->key->attr, "hash");
216 		if(hash == nil)
217 			hash = "sha1";
218 		if(strcmp(hash, "sha1") == 0)
219 			dlen = SHA1dlen;
220 		else if(strcmp(hash, "md5") == 0)
221 			dlen = MD5dlen;
222 		else
223 			return failure(fss, "unknown hash function %s", hash);
224 		if(n != dlen)
225 			return failure(fss, "hash length %d should be %d",
226 				n, dlen);
227 		priv = s->key->priv;
228 		s->digest = mkdigest(&priv->pub, hash, (uchar *)va, n);
229 		if(s->digest == nil)
230 			return failure(fss, nil);
231 		if(fss->phase == VNeedHash)
232 			fss->phase = VNeedSig;
233 		else
234 			fss->phase = SHaveResp;
235 		return RpcOk;
236 	case VNeedSig:
237 		priv = s->key->priv;
238 		m = betomp((uchar*)va, n, nil);
239 		mm = rsaencrypt(&priv->pub, m, nil);
240 		s->sigresp = mpcmp(s->digest, mm);
241 		mpfree(m);
242 		mpfree(mm);
243 		fss->phase = VHaveResp;
244 		return RpcOk;
245 	}
246 }
247 
248 static void
rsaclose(Fsstate * fss)249 rsaclose(Fsstate *fss)
250 {
251 	State *s;
252 
253 	s = fss->ps;
254 	if(s->key)
255 		closekey(s->key);
256 	if(s->resp)
257 		mpfree(s->resp);
258 	if(s->digest)
259 		mpfree(s->digest);
260 	free(s);
261 }
262 
263 static int
rsaaddkey(Key * k,int before)264 rsaaddkey(Key *k, int before)
265 {
266 	fmtinstall('B', mpfmt);
267 
268 	if((k->priv = readrsapriv(k)) == nil){
269 		werrstr("malformed key data");
270 		return -1;
271 	}
272 	return replacekey(k, before);
273 }
274 
275 static void
rsaclosekey(Key * k)276 rsaclosekey(Key *k)
277 {
278 	rsaprivfree(k->priv);
279 }
280 
281 Proto rsa = {
282 .name=	"rsa",
283 .init=		rsainit,
284 .write=	rsawrite,
285 .read=	rsaread,
286 .close=	rsaclose,
287 .addkey=	rsaaddkey,
288 .closekey=	rsaclosekey,
289 };
290 
/*
 * Simple ASN.1 encodings.
 * Lengths < 128 are encoded in DER short form, as a single
 * byte, which makes our life easy.
 */
296 
297 /*
298  * Hash OIDs
299  *
300  * SHA1 = 1.3.14.3.2.26
301  * MDx = 1.2.840.113549.2.x
302  */
303 #define O0(a,b)	((a)*40+(b))
304 #define O2(x)	\
305 	(((x)>> 7)&0x7F)|0x80, \
306 	((x)&0x7F)
307 #define O3(x)	\
308 	(((x)>>14)&0x7F)|0x80, \
309 	(((x)>> 7)&0x7F)|0x80, \
310 	((x)&0x7F)
311 uchar oidsha1[] = { O0(1, 3), 14, 3, 2, 26 };
312 uchar oidmd2[] = { O0(1, 2), O2(840), O3(113549), 2, 2 };
313 uchar oidmd5[] = { O0(1, 2), O2(840), O3(113549), 2, 5 };
314 
315 /*
316  *	DigestInfo ::= SEQUENCE {
317  *		digestAlgorithm AlgorithmIdentifier,
318  *		digest OCTET STRING
319  *	}
320  *
321  * except that OpenSSL seems to sign
322  *
323  *	DigestInfo ::= SEQUENCE {
324  *		SEQUENCE{ digestAlgorithm AlgorithmIdentifier, NULL }
325  *		digest OCTET STRING
326  *	}
327  *
328  * instead.  Sigh.
329  */
330 static int
mkasn1(uchar * asn1,char * alg,uchar * d,uint dlen)331 mkasn1(uchar *asn1, char *alg, uchar *d, uint dlen)
332 {
333 	uchar *obj, *p;
334 	uint olen;
335 
336 	if(strcmp(alg, "sha1") == 0){
337 		obj = oidsha1;
338 		olen = sizeof(oidsha1);
339 	}else if(strcmp(alg, "md5") == 0){
340 		obj = oidmd5;
341 		olen = sizeof(oidmd5);
342 	}else{
343 		sysfatal("bad alg in mkasn1");
344 		return -1;
345 	}
346 
347 	p = asn1;
348 	*p++ = 0x30;		/* sequence */
349 	p++;
350 
351 	*p++ = 0x30;		/* another sequence */
352 	p++;
353 
354 	*p++ = 0x06;		/* object id */
355 	*p++ = olen;
356 	memmove(p, obj, olen);
357 	p += olen;
358 
359 	*p++ = 0x05;		/* null */
360 	*p++ = 0;
361 
362 	asn1[3] = p - (asn1+4);	/* end of inner sequence */
363 
364 	*p++ = 0x04;		/* octet string */
365 	*p++ = dlen;
366 	memmove(p, d, dlen);
367 	p += dlen;
368 
369 	asn1[1] = p - (asn1+2);	/* end of outer sequence */
370 	return p - asn1;
371 }
372 
373 static mpint*
mkdigest(RSApub * key,char * hashalg,uchar * hash,uint dlen)374 mkdigest(RSApub *key, char *hashalg, uchar *hash, uint dlen)
375 {
376 	mpint *m;
377 	uchar asn1[512], *buf;
378 	int len, n, pad;
379 
380 	/*
381 	 * Create ASN.1
382 	 */
383 	n = mkasn1(asn1, hashalg, hash, dlen);
384 
385 	/*
386 	 * PKCS#1 padding
387 	 */
388 	len = (mpsignif(key->n)+7)/8 - 1;
389 	if(len < n+2){
390 		werrstr("rsa key too short");
391 		return nil;
392 	}
393 	pad = len - (n+2);
394 	buf = emalloc(len);
395 	buf[0] = 0x01;
396 	memset(buf+1, 0xFF, pad);
397 	buf[1+pad] = 0x00;
398 	memmove(buf+1+pad+1, asn1, n);
399 	m = betomp(buf, len, nil);
400 	free(buf);
401 	return m;
402 }
403