1 #include "ssh.h" 2 3 static ulong sum32(ulong, void*, int); 4 5 char *msgnames[] = 6 { 7 /* 0 */ 8 "SSH_MSG_NONE", 9 "SSH_MSG_DISCONNECT", 10 "SSH_SMSG_PUBLIC_KEY", 11 "SSH_CMSG_SESSION_KEY", 12 "SSH_CMSG_USER", 13 "SSH_CMSG_AUTH_RHOSTS", 14 "SSH_CMSG_AUTH_RSA", 15 "SSH_SMSG_AUTH_RSA_CHALLENGE", 16 "SSH_CMSG_AUTH_RSA_RESPONSE", 17 "SSH_CMSG_AUTH_PASSWORD", 18 19 /* 10 */ 20 "SSH_CMSG_REQUEST_PTY", 21 "SSH_CMSG_WINDOW_SIZE", 22 "SSH_CMSG_EXEC_SHELL", 23 "SSH_CMSG_EXEC_CMD", 24 "SSH_SMSG_SUCCESS", 25 "SSH_SMSG_FAILURE", 26 "SSH_CMSG_STDIN_DATA", 27 "SSH_SMSG_STDOUT_DATA", 28 "SSH_SMSG_STDERR_DATA", 29 "SSH_CMSG_EOF", 30 31 /* 20 */ 32 "SSH_SMSG_EXITSTATUS", 33 "SSH_MSG_CHANNEL_OPEN_CONFIRMATION", 34 "SSH_MSG_CHANNEL_OPEN_FAILURE", 35 "SSH_MSG_CHANNEL_DATA", 36 "SSH_MSG_CHANNEL_INPUT_EOF", 37 "SSH_MSG_CHANNEL_OUTPUT_CLOSED", 38 "SSH_MSG_UNIX_DOMAIN_X11_FORWARDING (obsolete)", 39 "SSH_SMSG_X11_OPEN", 40 "SSH_CMSG_PORT_FORWARD_REQUEST", 41 "SSH_MSG_PORT_OPEN", 42 43 /* 30 */ 44 "SSH_CMSG_AGENT_REQUEST_FORWARDING", 45 "SSH_SMSG_AGENT_OPEN", 46 "SSH_MSG_IGNORE", 47 "SSH_CMSG_EXIT_CONFIRMATION", 48 "SSH_CMSG_X11_REQUEST_FORWARDING", 49 "SSH_CMSG_AUTH_RHOSTS_RSA", 50 "SSH_MSG_DEBUG", 51 "SSH_CMSG_REQUEST_COMPRESSION", 52 "SSH_CMSG_MAX_PACKET_SIZE", 53 "SSH_CMSG_AUTH_TIS", 54 55 /* 40 */ 56 "SSH_SMSG_AUTH_TIS_CHALLENGE", 57 "SSH_CMSG_AUTH_TIS_RESPONSE", 58 "SSH_CMSG_AUTH_KERBEROS", 59 "SSH_SMSG_AUTH_KERBEROS_RESPONSE", 60 "SSH_CMSG_HAVE_KERBEROS_TGT" 61 }; 62 63 void 64 badmsg(Msg *m, int want) 65 { 66 char *s, buf[20+ERRMAX]; 67 68 if(m==nil){ 69 snprint(buf, sizeof buf, "<early eof: %r>"); 70 s = buf; 71 }else{ 72 snprint(buf, sizeof buf, "<unknown type %d>", m->type); 73 s = buf; 74 if(0 <= m->type && m->type < nelem(msgnames)) 75 s = msgnames[m->type]; 76 } 77 if(want) 78 error("got %s message expecting %s", s, msgnames[want]); 79 error("got unexpected %s message", s); 80 } 81 82 Msg* 83 allocmsg(Conn *c, int type, int len) 84 { 85 uchar *p; 86 Msg *m; 87 88 if(len > 256*1024) 89 abort(); 90 91 m = (Msg*)emalloc(sizeof(Msg)+4+8+1+len+4); 92 setmalloctag(m, getcallerpc(&c)); 93 p = (uchar*)&m[1]; 94 m->c = c; 95 m->bp = p; 96 m->ep = p+len; 97 m->wp = p; 98 m->type = type; 99 return m; 100 } 101 102 void 103 unrecvmsg(Conn *c, Msg *m) 104 { 105 debug(DBG_PROTO, "unreceived %s len %ld\n", msgnames[m->type], m->ep - m->rp); 106 free(c->unget); 107 c->unget = m; 108 } 109 110 static Msg* 111 recvmsg0(Conn *c) 112 { 113 int pad; 114 uchar *p, buf[4]; 115 ulong crc, crc0, len; 116 Msg *m; 117 118 if(c->unget){ 119 m = c->unget; 120 c->unget = nil; 121 return m; 122 } 123 124 if(readn(c->fd[0], buf, 4) != 4){ 125 werrstr("short net read: %r"); 126 return nil; 127 } 128 129 len = LONG(buf); 130 if(len > 256*1024){ 131 werrstr("packet size far too big: %.8lux", len); 132 return nil; 133 } 134 135 pad = 8 - len%8; 136 137 m = (Msg*)emalloc(sizeof(Msg)+pad+len); 138 setmalloctag(m, getcallerpc(&c)); 139 m->c = c; 140 m->bp = (uchar*)&m[1]; 141 m->ep = m->bp + pad+len-4; /* -4: don't include crc */ 142 m->rp = m->bp; 143 144 if(readn(c->fd[0], m->bp, pad+len) != pad+len){ 145 werrstr("short net read: %r"); 146 free(m); 147 return nil; 148 } 149 150 if(c->cipher) 151 c->cipher->decrypt(c->cstate, m->bp, len+pad); 152 153 crc = sum32(0, m->bp, pad+len-4); 154 p = m->bp + pad+len-4; 155 crc0 = LONG(p); 156 if(crc != crc0){ 157 werrstr("bad crc %#lux != %#lux (packet length %lud)", crc, crc0, len); 158 free(m); 159 return nil; 160 } 161 162 m->rp += pad; 163 m->type = *m->rp++; 164 165 return m; 166 } 167 168 Msg* 169 recvmsg(Conn *c, int type) 170 { 171 Msg *m; 172 173 while((m = recvmsg0(c)) != nil){ 174 debug(DBG_PROTO, "received %s len %ld\n", msgnames[m->type], m->ep - m->rp); 175 if(m->type != SSH_MSG_DEBUG && m->type != SSH_MSG_IGNORE) 176 break; 177 if(m->type == SSH_MSG_DEBUG) 178 debug(DBG_PROTO, "remote DEBUG: %s\n", getstring(m)); 179 free(m); 180 } 181 if(type == 0){ 182 /* no checking */ 183 }else if(type == -1){ 184 /* must not be nil */ 185 if(m == nil) 186 error(Ehangup); 187 }else{ 188 /* must be given type */ 189 if(m==nil || m->type!=type) 190 badmsg(m, type); 191 } 192 setmalloctag(m, getcallerpc(&c)); 193 return m; 194 } 195 196 int 197 sendmsg(Msg *m) 198 { 199 int i, pad; 200 uchar *p; 201 ulong datalen, len, crc; 202 Conn *c; 203 204 datalen = m->wp - m->bp; 205 len = datalen + 5; 206 pad = 8 - len%8; 207 208 debug(DBG_PROTO, "sending %s len %lud\n", msgnames[m->type], datalen); 209 210 p = m->bp; 211 memmove(m->bp+4+pad+1, m->bp, datalen); /* slide data to correct position */ 212 213 PLONG(p, len); 214 p += 4; 215 216 if(m->c->cstate){ 217 for(i=0; i<pad; i++) 218 *p++ = fastrand(); 219 }else{ 220 memset(p, 0, pad); 221 p += pad; 222 } 223 224 *p++ = m->type; 225 226 /* data already in position */ 227 p += datalen; 228 229 crc = sum32(0, m->bp+4, pad+1+datalen); 230 PLONG(p, crc); 231 p += 4; 232 233 c = m->c; 234 qlock(c); 235 if(c->cstate) 236 c->cipher->encrypt(c->cstate, m->bp+4, len+pad); 237 238 if(write(c->fd[1], m->bp, p - m->bp) != p-m->bp){ 239 qunlock(c); 240 free(m); 241 return -1; 242 } 243 qunlock(c); 244 free(m); 245 return 0; 246 } 247 248 uchar 249 getbyte(Msg *m) 250 { 251 if(m->rp >= m->ep) 252 error(Edecode); 253 return *m->rp++; 254 } 255 256 ushort 257 getshort(Msg *m) 258 { 259 ushort x; 260 261 if(m->rp+2 > m->ep) 262 error(Edecode); 263 264 x = SHORT(m->rp); 265 m->rp += 2; 266 return x; 267 } 268 269 ulong 270 getlong(Msg *m) 271 { 272 ulong x; 273 274 if(m->rp+4 > m->ep) 275 error(Edecode); 276 277 x = LONG(m->rp); 278 m->rp += 4; 279 return x; 280 } 281 282 char* 283 getstring(Msg *m) 284 { 285 char *p; 286 ulong len; 287 288 /* overwrites length to make room for NUL */ 289 len = getlong(m); 290 if(m->rp+len > m->ep) 291 error(Edecode); 292 p = (char*)m->rp-1; 293 memmove(p, m->rp, len); 294 p[len] = '\0'; 295 return p; 296 } 297 298 void* 299 getbytes(Msg *m, int n) 300 { 301 uchar *p; 302 303 if(m->rp+n > m->ep) 304 error(Edecode); 305 p = m->rp; 306 m->rp += n; 307 return p; 308 } 309 310 mpint* 311 getmpint(Msg *m) 312 { 313 int n; 314 315 n = (getshort(m)+7)/8; /* getshort returns # bits */ 316 return betomp(getbytes(m, n), n, nil); 317 } 318 319 RSApub* 320 getRSApub(Msg *m) 321 { 322 RSApub *key; 323 324 getlong(m); 325 key = rsapuballoc(); 326 if(key == nil) 327 error(Ememory); 328 key->ek = getmpint(m); 329 key->n = getmpint(m); 330 setmalloctag(key, getcallerpc(&m)); 331 return key; 332 } 333 334 void 335 putbyte(Msg *m, uchar x) 336 { 337 if(m->wp >= m->ep) 338 error(Eencode); 339 *m->wp++ = x; 340 } 341 342 void 343 putshort(Msg *m, ushort x) 344 { 345 if(m->wp+2 > m->ep) 346 error(Eencode); 347 PSHORT(m->wp, x); 348 m->wp += 2; 349 } 350 351 void 352 putlong(Msg *m, ulong x) 353 { 354 if(m->wp+4 > m->ep) 355 error(Eencode); 356 PLONG(m->wp, x); 357 m->wp += 4; 358 } 359 360 void 361 putstring(Msg *m, char *s) 362 { 363 int len; 364 365 len = strlen(s); 366 putlong(m, len); 367 putbytes(m, s, len); 368 } 369 370 void 371 putbytes(Msg *m, void *a, long n) 372 { 373 if(m->wp+n > m->ep) 374 error(Eencode); 375 memmove(m->wp, a, n); 376 m->wp += n; 377 } 378 379 void 380 putmpint(Msg *m, mpint *b) 381 { 382 int bits, n; 383 384 bits = mpsignif(b); 385 putshort(m, bits); 386 n = (bits+7)/8; 387 if(m->wp+n > m->ep) 388 error(Eencode); 389 mptobe(b, m->wp, n, nil); 390 m->wp += n; 391 } 392 393 void 394 putRSApub(Msg *m, RSApub *key) 395 { 396 putlong(m, mpsignif(key->n)); 397 putmpint(m, key->ek); 398 putmpint(m, key->n); 399 } 400 401 static ulong crctab[256]; 402 403 static void 404 initsum32(void) 405 { 406 ulong crc, poly; 407 int i, j; 408 409 poly = 0xEDB88320; 410 for(i = 0; i < 256; i++){ 411 crc = i; 412 for(j = 0; j < 8; j++){ 413 if(crc & 1) 414 crc = (crc >> 1) ^ poly; 415 else 416 crc >>= 1; 417 } 418 crctab[i] = crc; 419 } 420 } 421 422 static ulong 423 sum32(ulong lcrc, void *buf, int n) 424 { 425 static int first=1; 426 uchar *s = buf; 427 ulong crc = lcrc; 428 429 if(first){ 430 first=0; 431 initsum32(); 432 } 433 while(n-- > 0) 434 crc = crctab[(crc^*s++)&0xff] ^ (crc>>8); 435 return crc; 436 } 437 438 mpint* 439 rsapad(mpint *b, int n) 440 { 441 int i, pad, nbuf; 442 uchar buf[2560]; 443 mpint *c; 444 445 if(n > sizeof buf) 446 error("buffer too small in rsapad"); 447 448 nbuf = (mpsignif(b)+7)/8; 449 pad = n - nbuf; 450 assert(pad >= 3); 451 mptobe(b, buf, nbuf, nil); 452 memmove(buf+pad, buf, nbuf); 453 454 buf[0] = 0; 455 buf[1] = 2; 456 for(i=2; i<pad-1; i++) 457 buf[i]=1+fastrand()%255; 458 buf[pad-1] = 0; 459 c = betomp(buf, n, nil); 460 memset(buf, 0, sizeof buf); 461 return c; 462 } 463 464 mpint* 465 rsaunpad(mpint *b) 466 { 467 int i, n; 468 uchar buf[2560]; 469 470 n = (mpsignif(b)+7)/8; 471 if(n > sizeof buf) 472 error("buffer too small in rsaunpad"); 473 mptobe(b, buf, n, nil); 474 475 /* the initial zero has been eaten by the betomp -> mptobe sequence */ 476 if(buf[0] != 2) 477 error("bad data in rsaunpad"); 478 for(i=1; i<n; i++) 479 if(buf[i]==0) 480 break; 481 return betomp(buf+i, n-i, nil); 482 } 483 484 void 485 mptoberjust(mpint *b, uchar *buf, int len) 486 { 487 int n; 488 489 n = mptobe(b, buf, len, nil); 490 assert(n >= 0); 491 if(n < len){ 492 len -= n; 493 memmove(buf+len, buf, n); 494 memset(buf, 0, len); 495 } 496 } 497 498 mpint* 499 rsaencryptbuf(RSApub *key, uchar *buf, int nbuf) 500 { 501 int n; 502 mpint *a, *b, *c; 503 504 n = (mpsignif(key->n)+7)/8; 505 a = betomp(buf, nbuf, nil); 506 b = rsapad(a, n); 507 mpfree(a); 508 c = rsaencrypt(key, b, nil); 509 mpfree(b); 510 return c; 511 } 512 513