xref: /plan9/sys/src/cmd/ssh1/msg.c (revision 63afb9a5d3f910047231762bcce0ee49fed3d07c)
1 #include "ssh.h"
2 
3 static ulong sum32(ulong, void*, int);
4 
/*
 * Printable names of the SSH 1 message types, indexed by the
 * numeric type code carried in each packet.  Codes at or beyond
 * nelem(msgnames) have no entry here.
 */
char *msgnames[] = {
	/*  0 */ "SSH_MSG_NONE",
	/*  1 */ "SSH_MSG_DISCONNECT",
	/*  2 */ "SSH_SMSG_PUBLIC_KEY",
	/*  3 */ "SSH_CMSG_SESSION_KEY",
	/*  4 */ "SSH_CMSG_USER",
	/*  5 */ "SSH_CMSG_AUTH_RHOSTS",
	/*  6 */ "SSH_CMSG_AUTH_RSA",
	/*  7 */ "SSH_SMSG_AUTH_RSA_CHALLENGE",
	/*  8 */ "SSH_CMSG_AUTH_RSA_RESPONSE",
	/*  9 */ "SSH_CMSG_AUTH_PASSWORD",

	/* 10 */ "SSH_CMSG_REQUEST_PTY",
	/* 11 */ "SSH_CMSG_WINDOW_SIZE",
	/* 12 */ "SSH_CMSG_EXEC_SHELL",
	/* 13 */ "SSH_CMSG_EXEC_CMD",
	/* 14 */ "SSH_SMSG_SUCCESS",
	/* 15 */ "SSH_SMSG_FAILURE",
	/* 16 */ "SSH_CMSG_STDIN_DATA",
	/* 17 */ "SSH_SMSG_STDOUT_DATA",
	/* 18 */ "SSH_SMSG_STDERR_DATA",
	/* 19 */ "SSH_CMSG_EOF",

	/* 20 */ "SSH_SMSG_EXITSTATUS",
	/* 21 */ "SSH_MSG_CHANNEL_OPEN_CONFIRMATION",
	/* 22 */ "SSH_MSG_CHANNEL_OPEN_FAILURE",
	/* 23 */ "SSH_MSG_CHANNEL_DATA",
	/* 24 */ "SSH_MSG_CHANNEL_INPUT_EOF",
	/* 25 */ "SSH_MSG_CHANNEL_OUTPUT_CLOSED",
	/* 26 */ "SSH_MSG_UNIX_DOMAIN_X11_FORWARDING (obsolete)",
	/* 27 */ "SSH_SMSG_X11_OPEN",
	/* 28 */ "SSH_CMSG_PORT_FORWARD_REQUEST",
	/* 29 */ "SSH_MSG_PORT_OPEN",

	/* 30 */ "SSH_CMSG_AGENT_REQUEST_FORWARDING",
	/* 31 */ "SSH_SMSG_AGENT_OPEN",
	/* 32 */ "SSH_MSG_IGNORE",
	/* 33 */ "SSH_CMSG_EXIT_CONFIRMATION",
	/* 34 */ "SSH_CMSG_X11_REQUEST_FORWARDING",
	/* 35 */ "SSH_CMSG_AUTH_RHOSTS_RSA",
	/* 36 */ "SSH_MSG_DEBUG",
	/* 37 */ "SSH_CMSG_REQUEST_COMPRESSION",
	/* 38 */ "SSH_CMSG_MAX_PACKET_SIZE",
	/* 39 */ "SSH_CMSG_AUTH_TIS",

	/* 40 */ "SSH_SMSG_AUTH_TIS_CHALLENGE",
	/* 41 */ "SSH_CMSG_AUTH_TIS_RESPONSE",
	/* 42 */ "SSH_CMSG_AUTH_KERBEROS",
	/* 43 */ "SSH_SMSG_AUTH_KERBEROS_RESPONSE",
	/* 44 */ "SSH_CMSG_HAVE_KERBEROS_TGT"
};
62 
63 void
badmsg(Msg * m,int want)64 badmsg(Msg *m, int want)
65 {
66 	char *s, buf[20+ERRMAX];
67 
68 	if(m==nil){
69 		snprint(buf, sizeof buf, "<early eof: %r>");
70 		s = buf;
71 	}else{
72 		snprint(buf, sizeof buf, "<unknown type %d>", m->type);
73 		s = buf;
74 		if(0 <= m->type && m->type < nelem(msgnames))
75 			s = msgnames[m->type];
76 	}
77 	if(want)
78 		error("got %s message expecting %s", s, msgnames[want]);
79 	error("got unexpected %s message", s);
80 }
81 
82 Msg*
allocmsg(Conn * c,int type,int len)83 allocmsg(Conn *c, int type, int len)
84 {
85 	uchar *p;
86 	Msg *m;
87 
88 	if(len > 256*1024)
89 		abort();
90 
91 	m = (Msg*)emalloc(sizeof(Msg)+4+8+1+len+4);
92 	setmalloctag(m, getcallerpc(&c));
93 	p = (uchar*)&m[1];
94 	m->c = c;
95 	m->bp = p;
96 	m->ep = p+len;
97 	m->wp = p;
98 	m->type = type;
99 	return m;
100 }
101 
102 void
unrecvmsg(Conn * c,Msg * m)103 unrecvmsg(Conn *c, Msg *m)
104 {
105 	debug(DBG_PROTO, "unreceived %s len %ld\n", msgnames[m->type], m->ep - m->rp);
106 	free(c->unget);
107 	c->unget = m;
108 }
109 
110 static Msg*
recvmsg0(Conn * c)111 recvmsg0(Conn *c)
112 {
113 	int pad;
114 	uchar *p, buf[4];
115 	ulong crc, crc0, len;
116 	Msg *m;
117 
118 	if(c->unget){
119 		m = c->unget;
120 		c->unget = nil;
121 		return m;
122 	}
123 
124 	if(readn(c->fd[0], buf, 4) != 4){
125 		werrstr("short net read: %r");
126 		return nil;
127 	}
128 
129 	len = LONG(buf);
130 	if(len > 256*1024){
131 		werrstr("packet size far too big: %.8lux", len);
132 		return nil;
133 	}
134 
135 	pad = 8 - len%8;
136 
137 	m = (Msg*)emalloc(sizeof(Msg)+pad+len);
138 	setmalloctag(m, getcallerpc(&c));
139 	m->c = c;
140 	m->bp = (uchar*)&m[1];
141 	m->ep = m->bp + pad+len-4;	/* -4: don't include crc */
142 	m->rp = m->bp;
143 
144 	if(readn(c->fd[0], m->bp, pad+len) != pad+len){
145 		werrstr("short net read: %r");
146 		free(m);
147 		return nil;
148 	}
149 
150 	if(c->cipher)
151 		c->cipher->decrypt(c->cstate, m->bp, len+pad);
152 
153 	crc = sum32(0, m->bp, pad+len-4);
154 	p = m->bp + pad+len-4;
155 	crc0 = LONG(p);
156 	if(crc != crc0){
157 		werrstr("bad crc %#lux != %#lux (packet length %lud)", crc, crc0, len);
158 		free(m);
159 		return nil;
160 	}
161 
162 	m->rp += pad;
163 	m->type = *m->rp++;
164 
165 	return m;
166 }
167 
168 Msg*
recvmsg(Conn * c,int type)169 recvmsg(Conn *c, int type)
170 {
171 	Msg *m;
172 
173 	while((m = recvmsg0(c)) != nil){
174 		debug(DBG_PROTO, "received %s len %ld\n", msgnames[m->type], m->ep - m->rp);
175 		if(m->type != SSH_MSG_DEBUG && m->type != SSH_MSG_IGNORE)
176 			break;
177 		if(m->type == SSH_MSG_DEBUG)
178 			debug(DBG_PROTO, "remote DEBUG: %s\n", getstring(m));
179 		free(m);
180 	}
181 	if(type == 0){
182 		/* no checking */
183 	}else if(type == -1){
184 		/* must not be nil */
185 		if(m == nil)
186 			error(Ehangup);
187 	}else{
188 		/* must be given type */
189 		if(m==nil || m->type!=type)
190 			badmsg(m, type);
191 	}
192 	setmalloctag(m, getcallerpc(&c));
193 	return m;
194 }
195 
196 int
sendmsg(Msg * m)197 sendmsg(Msg *m)
198 {
199 	int i, pad;
200 	uchar *p;
201 	ulong datalen, len, crc;
202 	Conn *c;
203 
204 	datalen = m->wp - m->bp;
205 	len = datalen + 5;
206 	pad = 8 - len%8;
207 
208 	debug(DBG_PROTO, "sending %s len %lud\n", msgnames[m->type], datalen);
209 
210 	p = m->bp;
211 	memmove(m->bp+4+pad+1, m->bp, datalen);	/* slide data to correct position */
212 
213 	PLONG(p, len);
214 	p += 4;
215 
216 	if(m->c->cstate){
217 		for(i=0; i<pad; i++)
218 			*p++ = fastrand();
219 	}else{
220 		memset(p, 0, pad);
221 		p += pad;
222 	}
223 
224 	*p++ = m->type;
225 
226 	/* data already in position */
227 	p += datalen;
228 
229 	crc = sum32(0, m->bp+4, pad+1+datalen);
230 	PLONG(p, crc);
231 	p += 4;
232 
233 	c = m->c;
234 	qlock(c);
235 	if(c->cstate)
236 		c->cipher->encrypt(c->cstate, m->bp+4, len+pad);
237 
238 	if(write(c->fd[1], m->bp, p - m->bp) != p-m->bp){
239 		qunlock(c);
240 		free(m);
241 		return -1;
242 	}
243 	qunlock(c);
244 	free(m);
245 	return 0;
246 }
247 
248 uchar
getbyte(Msg * m)249 getbyte(Msg *m)
250 {
251 	if(m->rp >= m->ep)
252 		error(Edecode);
253 	return *m->rp++;
254 }
255 
256 ushort
getshort(Msg * m)257 getshort(Msg *m)
258 {
259 	ushort x;
260 
261 	if(m->rp+2 > m->ep)
262 		error(Edecode);
263 
264 	x = SHORT(m->rp);
265 	m->rp += 2;
266 	return x;
267 }
268 
269 ulong
getlong(Msg * m)270 getlong(Msg *m)
271 {
272 	ulong x;
273 
274 	if(m->rp+4 > m->ep)
275 		error(Edecode);
276 
277 	x = LONG(m->rp);
278 	m->rp += 4;
279 	return x;
280 }
281 
282 char*
getstring(Msg * m)283 getstring(Msg *m)
284 {
285 	char *p;
286 	ulong len;
287 
288 	/* overwrites length to make room for NUL */
289 	len = getlong(m);
290 	if(m->rp+len > m->ep)
291 		error(Edecode);
292 	p = (char*)m->rp-1;
293 	memmove(p, m->rp, len);
294 	p[len] = '\0';
295 	return p;
296 }
297 
298 void*
getbytes(Msg * m,int n)299 getbytes(Msg *m, int n)
300 {
301 	uchar *p;
302 
303 	if(m->rp+n > m->ep)
304 		error(Edecode);
305 	p = m->rp;
306 	m->rp += n;
307 	return p;
308 }
309 
310 mpint*
getmpint(Msg * m)311 getmpint(Msg *m)
312 {
313 	int n;
314 
315 	n = (getshort(m)+7)/8;	/* getshort returns # bits */
316 	return betomp(getbytes(m, n), n, nil);
317 }
318 
319 RSApub*
getRSApub(Msg * m)320 getRSApub(Msg *m)
321 {
322 	RSApub *key;
323 
324 	getlong(m);
325 	key = rsapuballoc();
326 	if(key == nil)
327 		error(Ememory);
328 	key->ek = getmpint(m);
329 	key->n = getmpint(m);
330 	setmalloctag(key, getcallerpc(&m));
331 	return key;
332 }
333 
334 void
putbyte(Msg * m,uchar x)335 putbyte(Msg *m, uchar x)
336 {
337 	if(m->wp >= m->ep)
338 		error(Eencode);
339 	*m->wp++ = x;
340 }
341 
342 void
putshort(Msg * m,ushort x)343 putshort(Msg *m, ushort x)
344 {
345 	if(m->wp+2 > m->ep)
346 		error(Eencode);
347 	PSHORT(m->wp, x);
348 	m->wp += 2;
349 }
350 
351 void
putlong(Msg * m,ulong x)352 putlong(Msg *m, ulong x)
353 {
354 	if(m->wp+4 > m->ep)
355 		error(Eencode);
356 	PLONG(m->wp, x);
357 	m->wp += 4;
358 }
359 
360 void
putstring(Msg * m,char * s)361 putstring(Msg *m, char *s)
362 {
363 	int len;
364 
365 	len = strlen(s);
366 	putlong(m, len);
367 	putbytes(m, s, len);
368 }
369 
370 void
putbytes(Msg * m,void * a,long n)371 putbytes(Msg *m, void *a, long n)
372 {
373 	if(m->wp+n > m->ep)
374 		error(Eencode);
375 	memmove(m->wp, a, n);
376 	m->wp += n;
377 }
378 
379 void
putmpint(Msg * m,mpint * b)380 putmpint(Msg *m, mpint *b)
381 {
382 	int bits, n;
383 
384 	bits = mpsignif(b);
385 	putshort(m, bits);
386 	n = (bits+7)/8;
387 	if(m->wp+n > m->ep)
388 		error(Eencode);
389 	mptobe(b, m->wp, n, nil);
390 	m->wp += n;
391 }
392 
393 void
putRSApub(Msg * m,RSApub * key)394 putRSApub(Msg *m, RSApub *key)
395 {
396 	putlong(m, mpsignif(key->n));
397 	putmpint(m, key->ek);
398 	putmpint(m, key->n);
399 }
400 
/* byte-indexed lookup table for sum32, filled in by initsum32 */
static ulong crctab[256];

/*
 * Build crctab: entry i is the value i after eight rounds of
 * shift-right-and-conditionally-xor with the polynomial 0xEDB88320.
 */
static void
initsum32(void)
{
	ulong crc, poly;
	int i, j;

	poly = 0xEDB88320;
	for(i = 0; i < 256; i++){
		crc = i;
		for(j = 0; j < 8; j++){
			if(crc & 1)
				crc = (crc >> 1) ^ poly;
			else
				crc >>= 1;
		}
		crctab[i] = crc;
	}
}

/*
 * Continue the table-driven checksum lcrc over the n bytes at buf
 * and return the updated value.  Pass 0 as lcrc to start a new
 * checksum.  No initial or final inversion is applied, and a
 * non-positive n returns lcrc unchanged.
 */
static ulong
sum32(ulong lcrc, void *buf, int n)
{
	static int first=1;
	uchar *s = buf;
	ulong crc = lcrc;

	if(first){
		/*
		 * Clear the flag only once the table is complete, so a
		 * caller that sees first==0 never finds a half-built
		 * table.  Building it twice is harmless: every entry
		 * gets the same value each time.
		 */
		initsum32();
		first=0;
	}
	while(n-- > 0)
		crc = crctab[(crc^*s++)&0xff] ^ (crc>>8);
	return crc;
}
437 
438 mpint*
rsapad(mpint * b,int n)439 rsapad(mpint *b, int n)
440 {
441 	int i, pad, nbuf;
442 	uchar buf[2560];
443 	mpint *c;
444 
445 	if(n > sizeof buf)
446 		error("buffer too small in rsapad");
447 
448 	nbuf = (mpsignif(b)+7)/8;
449 	pad = n - nbuf;
450 	assert(pad >= 3);
451 	mptobe(b, buf, nbuf, nil);
452 	memmove(buf+pad, buf, nbuf);
453 
454 	buf[0] = 0;
455 	buf[1] = 2;
456 	for(i=2; i<pad-1; i++)
457 		buf[i]=1+fastrand()%255;
458 	buf[pad-1] = 0;
459 	c = betomp(buf, n, nil);
460 	memset(buf, 0, sizeof buf);
461 	return c;
462 }
463 
464 mpint*
rsaunpad(mpint * b)465 rsaunpad(mpint *b)
466 {
467 	int i, n;
468 	uchar buf[2560];
469 
470 	n = (mpsignif(b)+7)/8;
471 	if(n > sizeof buf)
472 		error("buffer too small in rsaunpad");
473 	mptobe(b, buf, n, nil);
474 
475 	/* the initial zero has been eaten by the betomp -> mptobe sequence */
476 	if(buf[0] != 2)
477 		error("bad data in rsaunpad");
478 	for(i=1; i<n; i++)
479 		if(buf[i]==0)
480 			break;
481 	return betomp(buf+i, n-i, nil);
482 }
483 
484 void
mptoberjust(mpint * b,uchar * buf,int len)485 mptoberjust(mpint *b, uchar *buf, int len)
486 {
487 	int n;
488 
489 	n = mptobe(b, buf, len, nil);
490 	assert(n >= 0);
491 	if(n < len){
492 		len -= n;
493 		memmove(buf+len, buf, n);
494 		memset(buf, 0, len);
495 	}
496 }
497 
498 mpint*
rsaencryptbuf(RSApub * key,uchar * buf,int nbuf)499 rsaencryptbuf(RSApub *key, uchar *buf, int nbuf)
500 {
501 	int n;
502 	mpint *a, *b, *c;
503 
504 	n = (mpsignif(key->n)+7)/8;
505 	a = betomp(buf, nbuf, nil);
506 	b = rsapad(a, n);
507 	mpfree(a);
508 	c = rsaencrypt(key, b, nil);
509 	mpfree(b);
510 	return c;
511 }
512 
513