1 #include <u.h> 2 #include <libc.h> 3 #include <venti.h> 4 #include "session.h" 5 6 struct { 7 int version; 8 char *s; 9 } vtVersions[] = { 10 VtVersion02, "02", 11 0, 0, 12 }; 13 14 static char EBigString[] = "string too long"; 15 static char EBigPacket[] = "packet too long"; 16 static char ENullString[] = "missing string"; 17 static char EBadVersion[] = "bad format in version string"; 18 19 static Packet *vtRPC(VtSession *z, int op, Packet *p); 20 21 22 VtSession * 23 vtAlloc(void) 24 { 25 VtSession *z; 26 27 z = vtMemAllocZ(sizeof(VtSession)); 28 z->lk = vtLockAlloc(); 29 // z->inHash = vtSha1Alloc(); 30 z->inLock = vtLockAlloc(); 31 z->part = packetAlloc(); 32 // z->outHash = vtSha1Alloc(); 33 z->outLock = vtLockAlloc(); 34 z->fd = -1; 35 z->uid = vtStrDup("anonymous"); 36 z->sid = vtStrDup("anonymous"); 37 return z; 38 } 39 40 void 41 vtReset(VtSession *z) 42 { 43 vtLock(z->lk); 44 z->cstate = VtStateAlloc; 45 if(z->fd >= 0){ 46 vtFdClose(z->fd); 47 z->fd = -1; 48 } 49 vtUnlock(z->lk); 50 } 51 52 int 53 vtConnected(VtSession *z) 54 { 55 return z->cstate == VtStateConnected; 56 } 57 58 void 59 vtDisconnect(VtSession *z, int error) 60 { 61 Packet *p; 62 uchar *b; 63 64 vtDebug(z, "vtDisconnect\n"); 65 vtLock(z->lk); 66 if(z->cstate == VtStateConnected && !error && z->vtbl == nil) { 67 /* clean shutdown */ 68 p = packetAlloc(); 69 b = packetHeader(p, 2); 70 b[0] = VtQGoodbye; 71 b[1] = 0; 72 vtSendPacket(z, p); 73 } 74 if(z->fd >= 0) 75 vtFdClose(z->fd); 76 z->fd = -1; 77 z->cstate = VtStateClosed; 78 vtUnlock(z->lk); 79 } 80 81 void 82 vtClose(VtSession *z) 83 { 84 vtDisconnect(z, 0); 85 } 86 87 void 88 vtFree(VtSession *z) 89 { 90 if(z == nil) 91 return; 92 vtLockFree(z->lk); 93 vtSha1Free(z->inHash); 94 vtLockFree(z->inLock); 95 packetFree(z->part); 96 vtSha1Free(z->outHash); 97 vtLockFree(z->outLock); 98 vtMemFree(z->uid); 99 vtMemFree(z->sid); 100 vtMemFree(z->vtbl); 101 102 memset(z, 0, sizeof(VtSession)); 103 z->fd = -1; 104 105 vtMemFree(z); 106 } 107 108 char * 109 vtGetUid(VtSession *s) 110 { 111 return s->uid; 112 } 113 114 char * 115 vtGetSid(VtSession *z) 116 { 117 return z->sid; 118 } 119 120 int 121 vtSetDebug(VtSession *z, int debug) 122 { 123 int old; 124 vtLock(z->lk); 125 old = z->debug; 126 z->debug = debug; 127 vtUnlock(z->lk); 128 return old; 129 } 130 131 int 132 vtSetFd(VtSession *z, int fd) 133 { 134 vtLock(z->lk); 135 if(z->cstate != VtStateAlloc) { 136 vtSetError("bad state"); 137 vtUnlock(z->lk); 138 return 0; 139 } 140 if(z->fd >= 0) 141 vtFdClose(z->fd); 142 z->fd = fd; 143 vtUnlock(z->lk); 144 return 1; 145 } 146 147 int 148 vtGetFd(VtSession *z) 149 { 150 return z->fd; 151 } 152 153 int 154 vtSetCryptoStrength(VtSession *z, int c) 155 { 156 if(z->cstate != VtStateAlloc) { 157 vtSetError("bad state"); 158 return 0; 159 } 160 if(c != VtCryptoStrengthNone) { 161 vtSetError("not supported yet"); 162 return 0; 163 } 164 return 1; 165 } 166 167 int 168 vtGetCryptoStrength(VtSession *s) 169 { 170 return s->cryptoStrength; 171 } 172 173 int 174 vtSetCompression(VtSession *z, int fd) 175 { 176 vtLock(z->lk); 177 if(z->cstate != VtStateAlloc) { 178 vtSetError("bad state"); 179 vtUnlock(z->lk); 180 return 0; 181 } 182 z->fd = fd; 183 vtUnlock(z->lk); 184 return 1; 185 } 186 187 int 188 vtGetCompression(VtSession *s) 189 { 190 return s->compression; 191 } 192 193 int 194 vtGetCrypto(VtSession *s) 195 { 196 return s->crypto; 197 } 198 199 int 200 vtGetCodec(VtSession *s) 201 { 202 return s->codec; 203 } 204 205 char * 206 vtGetVersion(VtSession *z) 207 { 208 int v, i; 209 210 v = z->version; 211 if(v == 0) 212 return "unknown"; 213 for(i=0; vtVersions[i].version; i++) 214 if(vtVersions[i].version == v) 215 return vtVersions[i].s; 216 assert(0); 217 return 0; 218 } 219 220 /* hold z->inLock */ 221 static int 222 vtVersionRead(VtSession *z, char *prefix, int *ret) 223 { 224 char c; 225 char buf[VtMaxStringSize]; 226 char *q, *p, *pp; 227 int i; 228 229 q = prefix; 230 p = buf; 231 for(;;) { 232 if(p >= buf + sizeof(buf)) { 233 vtSetError(EBadVersion); 234 return 0; 235 } 236 if(!vtFdReadFully(z->fd, (uchar*)&c, 1)) 237 return 0; 238 if(z->inHash) 239 vtSha1Update(z->inHash, (uchar*)&c, 1); 240 if(c == '\n') { 241 *p = 0; 242 break; 243 } 244 if(c < ' ' || c > 0x7f || *q && c != *q) { 245 vtSetError(EBadVersion); 246 return 0; 247 } 248 *p++ = c; 249 if(*q) 250 q++; 251 } 252 253 vtDebug(z, "version string in: %s\n", buf); 254 255 p = buf + strlen(prefix); 256 for(;;) { 257 for(pp=p; *pp && *pp != ':' && *pp != '-'; pp++) 258 ; 259 for(i=0; vtVersions[i].version; i++) { 260 if(strlen(vtVersions[i].s) != pp-p) 261 continue; 262 if(memcmp(vtVersions[i].s, p, pp-p) == 0) { 263 *ret = vtVersions[i].version; 264 return 1; 265 } 266 } 267 p = pp; 268 if(*p != ':') 269 return 0; 270 p++; 271 } 272 } 273 274 Packet* 275 vtRecvPacket(VtSession *z) 276 { 277 uchar buf[10], *b; 278 int n; 279 Packet *p; 280 int size, len; 281 282 if(z->cstate != VtStateConnected) { 283 vtSetError("session not connected"); 284 return 0; 285 } 286 287 vtLock(z->inLock); 288 p = z->part; 289 /* get enough for head size */ 290 size = packetSize(p); 291 while(size < 2) { 292 b = packetTrailer(p, MaxFragSize); 293 assert(b != nil); 294 n = vtFdRead(z->fd, b, MaxFragSize); 295 if(n <= 0) 296 goto Err; 297 size += n; 298 packetTrim(p, 0, size); 299 } 300 301 if(!packetConsume(p, buf, 2)) 302 goto Err; 303 len = (buf[0] << 8) | buf[1]; 304 size -= 2; 305 306 while(size < len) { 307 n = len - size; 308 if(n > MaxFragSize) 309 n = MaxFragSize; 310 b = packetTrailer(p, n); 311 if(!vtFdReadFully(z->fd, b, n)) 312 goto Err; 313 size += n; 314 } 315 p = packetSplit(p, len); 316 vtUnlock(z->inLock); 317 return p; 318 Err: 319 vtUnlock(z->inLock); 320 return nil; 321 } 322 323 int 324 vtSendPacket(VtSession *z, Packet *p) 325 { 326 IOchunk ioc; 327 int n; 328 uchar buf[2]; 329 330 /* add framing */ 331 n = packetSize(p); 332 if(n >= (1<<16)) { 333 vtSetError(EBigPacket); 334 packetFree(p); 335 return 0; 336 } 337 buf[0] = n>>8; 338 buf[1] = n; 339 packetPrefix(p, buf, 2); 340 341 for(;;) { 342 n = packetFragments(p, &ioc, 1, 0); 343 if(n == 0) 344 break; 345 if(!vtFdWrite(z->fd, ioc.addr, ioc.len)) { 346 packetFree(p); 347 return 0; 348 } 349 packetConsume(p, nil, n); 350 } 351 packetFree(p); 352 return 1; 353 } 354 355 356 int 357 vtGetString(Packet *p, char **ret) 358 { 359 uchar buf[2]; 360 int n; 361 char *s; 362 363 if(!packetConsume(p, buf, 2)) 364 return 0; 365 n = (buf[0]<<8) + buf[1]; 366 if(n > VtMaxStringSize) { 367 vtSetError(EBigString); 368 return 0; 369 } 370 s = vtMemAlloc(n+1); 371 if(!packetConsume(p, (uchar*)s, n)) { 372 vtMemFree(s); 373 return 0; 374 } 375 s[n] = 0; 376 *ret = s; 377 return 1; 378 } 379 380 int 381 vtAddString(Packet *p, char *s) 382 { 383 uchar buf[2]; 384 int n; 385 386 if(s == nil) { 387 vtSetError(ENullString); 388 return 0; 389 } 390 n = strlen(s); 391 if(n > VtMaxStringSize) { 392 vtSetError(EBigString); 393 return 0; 394 } 395 buf[0] = n>>8; 396 buf[1] = n; 397 packetAppend(p, buf, 2); 398 packetAppend(p, (uchar*)s, n); 399 return 1; 400 } 401 402 int 403 vtConnect(VtSession *z, char *password) 404 { 405 char buf[VtMaxStringSize], *p, *ep, *prefix; 406 int i; 407 408 USED(password); 409 vtLock(z->lk); 410 if(z->cstate != VtStateAlloc) { 411 vtSetError("bad session state"); 412 vtUnlock(z->lk); 413 return 0; 414 } 415 if(z->fd < 0){ 416 vtSetError("%s", z->fderror); 417 vtUnlock(z->lk); 418 return 0; 419 } 420 421 /* be a little anal */ 422 vtLock(z->inLock); 423 vtLock(z->outLock); 424 425 prefix = "venti-"; 426 p = buf; 427 ep = buf + sizeof(buf); 428 p = seprint(p, ep, "%s", prefix); 429 p += strlen(p); 430 for(i=0; vtVersions[i].version; i++) { 431 if(i != 0) 432 *p++ = ':'; 433 p = seprint(p, ep, "%s", vtVersions[i].s); 434 } 435 p = seprint(p, ep, "-libventi\n"); 436 assert(p-buf < sizeof(buf)); 437 if(z->outHash) 438 vtSha1Update(z->outHash, (uchar*)buf, p-buf); 439 if(!vtFdWrite(z->fd, (uchar*)buf, p-buf)) 440 goto Err; 441 442 vtDebug(z, "version string out: %s", buf); 443 444 if(!vtVersionRead(z, prefix, &z->version)) 445 goto Err; 446 447 vtDebug(z, "version = %d: %s\n", z->version, vtGetVersion(z)); 448 449 vtUnlock(z->inLock); 450 vtUnlock(z->outLock); 451 z->cstate = VtStateConnected; 452 vtUnlock(z->lk); 453 454 if(z->vtbl) 455 return 1; 456 457 if(!vtHello(z)) 458 goto Err; 459 return 1; 460 Err: 461 if(z->fd >= 0) 462 vtFdClose(z->fd); 463 z->fd = -1; 464 vtUnlock(z->inLock); 465 vtUnlock(z->outLock); 466 z->cstate = VtStateClosed; 467 vtUnlock(z->lk); 468 return 0; 469 } 470 471