1 #include "ssh.h"
2
3 static ulong sum32(ulong, void*, int);
4
/*
 * Printable names of the SSH protocol 1 message types, indexed by the
 * numeric type (0 through 44).  Used for debug output and error
 * messages; a type taken from the network must be bounds-checked
 * against nelem(msgnames) before it is used as an index.
 */
char *msgnames[] =
{
	/*  0 */ "SSH_MSG_NONE",
	/*  1 */ "SSH_MSG_DISCONNECT",
	/*  2 */ "SSH_SMSG_PUBLIC_KEY",
	/*  3 */ "SSH_CMSG_SESSION_KEY",
	/*  4 */ "SSH_CMSG_USER",
	/*  5 */ "SSH_CMSG_AUTH_RHOSTS",
	/*  6 */ "SSH_CMSG_AUTH_RSA",
	/*  7 */ "SSH_SMSG_AUTH_RSA_CHALLENGE",
	/*  8 */ "SSH_CMSG_AUTH_RSA_RESPONSE",
	/*  9 */ "SSH_CMSG_AUTH_PASSWORD",
	/* 10 */ "SSH_CMSG_REQUEST_PTY",
	/* 11 */ "SSH_CMSG_WINDOW_SIZE",
	/* 12 */ "SSH_CMSG_EXEC_SHELL",
	/* 13 */ "SSH_CMSG_EXEC_CMD",
	/* 14 */ "SSH_SMSG_SUCCESS",
	/* 15 */ "SSH_SMSG_FAILURE",
	/* 16 */ "SSH_CMSG_STDIN_DATA",
	/* 17 */ "SSH_SMSG_STDOUT_DATA",
	/* 18 */ "SSH_SMSG_STDERR_DATA",
	/* 19 */ "SSH_CMSG_EOF",
	/* 20 */ "SSH_SMSG_EXITSTATUS",
	/* 21 */ "SSH_MSG_CHANNEL_OPEN_CONFIRMATION",
	/* 22 */ "SSH_MSG_CHANNEL_OPEN_FAILURE",
	/* 23 */ "SSH_MSG_CHANNEL_DATA",
	/* 24 */ "SSH_MSG_CHANNEL_INPUT_EOF",
	/* 25 */ "SSH_MSG_CHANNEL_OUTPUT_CLOSED",
	/* 26 */ "SSH_MSG_UNIX_DOMAIN_X11_FORWARDING (obsolete)",
	/* 27 */ "SSH_SMSG_X11_OPEN",
	/* 28 */ "SSH_CMSG_PORT_FORWARD_REQUEST",
	/* 29 */ "SSH_MSG_PORT_OPEN",
	/* 30 */ "SSH_CMSG_AGENT_REQUEST_FORWARDING",
	/* 31 */ "SSH_SMSG_AGENT_OPEN",
	/* 32 */ "SSH_MSG_IGNORE",
	/* 33 */ "SSH_CMSG_EXIT_CONFIRMATION",
	/* 34 */ "SSH_CMSG_X11_REQUEST_FORWARDING",
	/* 35 */ "SSH_CMSG_AUTH_RHOSTS_RSA",
	/* 36 */ "SSH_MSG_DEBUG",
	/* 37 */ "SSH_CMSG_REQUEST_COMPRESSION",
	/* 38 */ "SSH_CMSG_MAX_PACKET_SIZE",
	/* 39 */ "SSH_CMSG_AUTH_TIS",
	/* 40 */ "SSH_SMSG_AUTH_TIS_CHALLENGE",
	/* 41 */ "SSH_CMSG_AUTH_TIS_RESPONSE",
	/* 42 */ "SSH_CMSG_AUTH_KERBEROS",
	/* 43 */ "SSH_SMSG_AUTH_KERBEROS_RESPONSE",
	/* 44 */ "SSH_CMSG_HAVE_KERBEROS_TGT"
};
62
63 void
badmsg(Msg * m,int want)64 badmsg(Msg *m, int want)
65 {
66 char *s, buf[20+ERRMAX];
67
68 if(m==nil){
69 snprint(buf, sizeof buf, "<early eof: %r>");
70 s = buf;
71 }else{
72 snprint(buf, sizeof buf, "<unknown type %d>", m->type);
73 s = buf;
74 if(0 <= m->type && m->type < nelem(msgnames))
75 s = msgnames[m->type];
76 }
77 if(want)
78 error("got %s message expecting %s", s, msgnames[want]);
79 error("got unexpected %s message", s);
80 }
81
82 Msg*
allocmsg(Conn * c,int type,int len)83 allocmsg(Conn *c, int type, int len)
84 {
85 uchar *p;
86 Msg *m;
87
88 if(len > 256*1024)
89 abort();
90
91 m = (Msg*)emalloc(sizeof(Msg)+4+8+1+len+4);
92 setmalloctag(m, getcallerpc(&c));
93 p = (uchar*)&m[1];
94 m->c = c;
95 m->bp = p;
96 m->ep = p+len;
97 m->wp = p;
98 m->type = type;
99 return m;
100 }
101
102 void
unrecvmsg(Conn * c,Msg * m)103 unrecvmsg(Conn *c, Msg *m)
104 {
105 debug(DBG_PROTO, "unreceived %s len %ld\n", msgnames[m->type], m->ep - m->rp);
106 free(c->unget);
107 c->unget = m;
108 }
109
110 static Msg*
recvmsg0(Conn * c)111 recvmsg0(Conn *c)
112 {
113 int pad;
114 uchar *p, buf[4];
115 ulong crc, crc0, len;
116 Msg *m;
117
118 if(c->unget){
119 m = c->unget;
120 c->unget = nil;
121 return m;
122 }
123
124 if(readn(c->fd[0], buf, 4) != 4){
125 werrstr("short net read: %r");
126 return nil;
127 }
128
129 len = LONG(buf);
130 if(len > 256*1024){
131 werrstr("packet size far too big: %.8lux", len);
132 return nil;
133 }
134
135 pad = 8 - len%8;
136
137 m = (Msg*)emalloc(sizeof(Msg)+pad+len);
138 setmalloctag(m, getcallerpc(&c));
139 m->c = c;
140 m->bp = (uchar*)&m[1];
141 m->ep = m->bp + pad+len-4; /* -4: don't include crc */
142 m->rp = m->bp;
143
144 if(readn(c->fd[0], m->bp, pad+len) != pad+len){
145 werrstr("short net read: %r");
146 free(m);
147 return nil;
148 }
149
150 if(c->cipher)
151 c->cipher->decrypt(c->cstate, m->bp, len+pad);
152
153 crc = sum32(0, m->bp, pad+len-4);
154 p = m->bp + pad+len-4;
155 crc0 = LONG(p);
156 if(crc != crc0){
157 werrstr("bad crc %#lux != %#lux (packet length %lud)", crc, crc0, len);
158 free(m);
159 return nil;
160 }
161
162 m->rp += pad;
163 m->type = *m->rp++;
164
165 return m;
166 }
167
168 Msg*
recvmsg(Conn * c,int type)169 recvmsg(Conn *c, int type)
170 {
171 Msg *m;
172
173 while((m = recvmsg0(c)) != nil){
174 debug(DBG_PROTO, "received %s len %ld\n", msgnames[m->type], m->ep - m->rp);
175 if(m->type != SSH_MSG_DEBUG && m->type != SSH_MSG_IGNORE)
176 break;
177 if(m->type == SSH_MSG_DEBUG)
178 debug(DBG_PROTO, "remote DEBUG: %s\n", getstring(m));
179 free(m);
180 }
181 if(type == 0){
182 /* no checking */
183 }else if(type == -1){
184 /* must not be nil */
185 if(m == nil)
186 error(Ehangup);
187 }else{
188 /* must be given type */
189 if(m==nil || m->type!=type)
190 badmsg(m, type);
191 }
192 setmalloctag(m, getcallerpc(&c));
193 return m;
194 }
195
196 int
sendmsg(Msg * m)197 sendmsg(Msg *m)
198 {
199 int i, pad;
200 uchar *p;
201 ulong datalen, len, crc;
202 Conn *c;
203
204 datalen = m->wp - m->bp;
205 len = datalen + 5;
206 pad = 8 - len%8;
207
208 debug(DBG_PROTO, "sending %s len %lud\n", msgnames[m->type], datalen);
209
210 p = m->bp;
211 memmove(m->bp+4+pad+1, m->bp, datalen); /* slide data to correct position */
212
213 PLONG(p, len);
214 p += 4;
215
216 if(m->c->cstate){
217 for(i=0; i<pad; i++)
218 *p++ = fastrand();
219 }else{
220 memset(p, 0, pad);
221 p += pad;
222 }
223
224 *p++ = m->type;
225
226 /* data already in position */
227 p += datalen;
228
229 crc = sum32(0, m->bp+4, pad+1+datalen);
230 PLONG(p, crc);
231 p += 4;
232
233 c = m->c;
234 qlock(c);
235 if(c->cstate)
236 c->cipher->encrypt(c->cstate, m->bp+4, len+pad);
237
238 if(write(c->fd[1], m->bp, p - m->bp) != p-m->bp){
239 qunlock(c);
240 free(m);
241 return -1;
242 }
243 qunlock(c);
244 free(m);
245 return 0;
246 }
247
248 uchar
getbyte(Msg * m)249 getbyte(Msg *m)
250 {
251 if(m->rp >= m->ep)
252 error(Edecode);
253 return *m->rp++;
254 }
255
256 ushort
getshort(Msg * m)257 getshort(Msg *m)
258 {
259 ushort x;
260
261 if(m->rp+2 > m->ep)
262 error(Edecode);
263
264 x = SHORT(m->rp);
265 m->rp += 2;
266 return x;
267 }
268
269 ulong
getlong(Msg * m)270 getlong(Msg *m)
271 {
272 ulong x;
273
274 if(m->rp+4 > m->ep)
275 error(Edecode);
276
277 x = LONG(m->rp);
278 m->rp += 4;
279 return x;
280 }
281
282 char*
getstring(Msg * m)283 getstring(Msg *m)
284 {
285 char *p;
286 ulong len;
287
288 /* overwrites length to make room for NUL */
289 len = getlong(m);
290 if(m->rp+len > m->ep)
291 error(Edecode);
292 p = (char*)m->rp-1;
293 memmove(p, m->rp, len);
294 p[len] = '\0';
295 return p;
296 }
297
298 void*
getbytes(Msg * m,int n)299 getbytes(Msg *m, int n)
300 {
301 uchar *p;
302
303 if(m->rp+n > m->ep)
304 error(Edecode);
305 p = m->rp;
306 m->rp += n;
307 return p;
308 }
309
310 mpint*
getmpint(Msg * m)311 getmpint(Msg *m)
312 {
313 int n;
314
315 n = (getshort(m)+7)/8; /* getshort returns # bits */
316 return betomp(getbytes(m, n), n, nil);
317 }
318
319 RSApub*
getRSApub(Msg * m)320 getRSApub(Msg *m)
321 {
322 RSApub *key;
323
324 getlong(m);
325 key = rsapuballoc();
326 if(key == nil)
327 error(Ememory);
328 key->ek = getmpint(m);
329 key->n = getmpint(m);
330 setmalloctag(key, getcallerpc(&m));
331 return key;
332 }
333
334 void
putbyte(Msg * m,uchar x)335 putbyte(Msg *m, uchar x)
336 {
337 if(m->wp >= m->ep)
338 error(Eencode);
339 *m->wp++ = x;
340 }
341
342 void
putshort(Msg * m,ushort x)343 putshort(Msg *m, ushort x)
344 {
345 if(m->wp+2 > m->ep)
346 error(Eencode);
347 PSHORT(m->wp, x);
348 m->wp += 2;
349 }
350
351 void
putlong(Msg * m,ulong x)352 putlong(Msg *m, ulong x)
353 {
354 if(m->wp+4 > m->ep)
355 error(Eencode);
356 PLONG(m->wp, x);
357 m->wp += 4;
358 }
359
360 void
putstring(Msg * m,char * s)361 putstring(Msg *m, char *s)
362 {
363 int len;
364
365 len = strlen(s);
366 putlong(m, len);
367 putbytes(m, s, len);
368 }
369
370 void
putbytes(Msg * m,void * a,long n)371 putbytes(Msg *m, void *a, long n)
372 {
373 if(m->wp+n > m->ep)
374 error(Eencode);
375 memmove(m->wp, a, n);
376 m->wp += n;
377 }
378
379 void
putmpint(Msg * m,mpint * b)380 putmpint(Msg *m, mpint *b)
381 {
382 int bits, n;
383
384 bits = mpsignif(b);
385 putshort(m, bits);
386 n = (bits+7)/8;
387 if(m->wp+n > m->ep)
388 error(Eencode);
389 mptobe(b, m->wp, n, nil);
390 m->wp += n;
391 }
392
393 void
putRSApub(Msg * m,RSApub * key)394 putRSApub(Msg *m, RSApub *key)
395 {
396 putlong(m, mpsignif(key->n));
397 putmpint(m, key->ek);
398 putmpint(m, key->n);
399 }
400
401 static ulong crctab[256];
402
403 static void
initsum32(void)404 initsum32(void)
405 {
406 ulong crc, poly;
407 int i, j;
408
409 poly = 0xEDB88320;
410 for(i = 0; i < 256; i++){
411 crc = i;
412 for(j = 0; j < 8; j++){
413 if(crc & 1)
414 crc = (crc >> 1) ^ poly;
415 else
416 crc >>= 1;
417 }
418 crctab[i] = crc;
419 }
420 }
421
422 static ulong
sum32(ulong lcrc,void * buf,int n)423 sum32(ulong lcrc, void *buf, int n)
424 {
425 static int first=1;
426 uchar *s = buf;
427 ulong crc = lcrc;
428
429 if(first){
430 first=0;
431 initsum32();
432 }
433 while(n-- > 0)
434 crc = crctab[(crc^*s++)&0xff] ^ (crc>>8);
435 return crc;
436 }
437
438 mpint*
rsapad(mpint * b,int n)439 rsapad(mpint *b, int n)
440 {
441 int i, pad, nbuf;
442 uchar buf[2560];
443 mpint *c;
444
445 if(n > sizeof buf)
446 error("buffer too small in rsapad");
447
448 nbuf = (mpsignif(b)+7)/8;
449 pad = n - nbuf;
450 assert(pad >= 3);
451 mptobe(b, buf, nbuf, nil);
452 memmove(buf+pad, buf, nbuf);
453
454 buf[0] = 0;
455 buf[1] = 2;
456 for(i=2; i<pad-1; i++)
457 buf[i]=1+fastrand()%255;
458 buf[pad-1] = 0;
459 c = betomp(buf, n, nil);
460 memset(buf, 0, sizeof buf);
461 return c;
462 }
463
464 mpint*
rsaunpad(mpint * b)465 rsaunpad(mpint *b)
466 {
467 int i, n;
468 uchar buf[2560];
469
470 n = (mpsignif(b)+7)/8;
471 if(n > sizeof buf)
472 error("buffer too small in rsaunpad");
473 mptobe(b, buf, n, nil);
474
475 /* the initial zero has been eaten by the betomp -> mptobe sequence */
476 if(buf[0] != 2)
477 error("bad data in rsaunpad");
478 for(i=1; i<n; i++)
479 if(buf[i]==0)
480 break;
481 return betomp(buf+i, n-i, nil);
482 }
483
484 void
mptoberjust(mpint * b,uchar * buf,int len)485 mptoberjust(mpint *b, uchar *buf, int len)
486 {
487 int n;
488
489 n = mptobe(b, buf, len, nil);
490 assert(n >= 0);
491 if(n < len){
492 len -= n;
493 memmove(buf+len, buf, n);
494 memset(buf, 0, len);
495 }
496 }
497
498 mpint*
rsaencryptbuf(RSApub * key,uchar * buf,int nbuf)499 rsaencryptbuf(RSApub *key, uchar *buf, int nbuf)
500 {
501 int n;
502 mpint *a, *b, *c;
503
504 n = (mpsignif(key->n)+7)/8;
505 a = betomp(buf, nbuf, nil);
506 b = rsapad(a, n);
507 mpfree(a);
508 c = rsaencrypt(key, b, nil);
509 mpfree(b);
510 return c;
511 }
512
513