1 #include <u.h>
2 #include <libc.h>
3 #include <oventi.h>
4 #include "session.h"
5
6 struct {
7 int version;
8 char *s;
9 } vtVersions[] = {
10 VtVersion02, "02",
11 0, 0,
12 };
13
/* Error strings handed to vtSetError by the string/packet codecs and the version handshake. */
static char EBigString[] = "string too long";
static char EBigPacket[] = "packet too long";
static char ENullString[] = "missing string";
static char EBadVersion[] = "bad format in version string";

/* Forward declaration; the definition is not in this part of the file (presumably further down — verify). */
static Packet *vtRPC(VtSession *z, int op, Packet *p);
20
21
22 VtSession *
vtAlloc(void)23 vtAlloc(void)
24 {
25 VtSession *z;
26
27 z = vtMemAllocZ(sizeof(VtSession));
28 z->lk = vtLockAlloc();
29 // z->inHash = vtSha1Alloc();
30 z->inLock = vtLockAlloc();
31 z->part = packetAlloc();
32 // z->outHash = vtSha1Alloc();
33 z->outLock = vtLockAlloc();
34 z->fd = -1;
35 z->uid = vtStrDup("anonymous");
36 z->sid = vtStrDup("anonymous");
37 return z;
38 }
39
40 void
vtReset(VtSession * z)41 vtReset(VtSession *z)
42 {
43 vtLock(z->lk);
44 z->cstate = VtStateAlloc;
45 if(z->fd >= 0){
46 vtFdClose(z->fd);
47 z->fd = -1;
48 }
49 vtUnlock(z->lk);
50 }
51
52 int
vtConnected(VtSession * z)53 vtConnected(VtSession *z)
54 {
55 return z->cstate == VtStateConnected;
56 }
57
58 void
vtDisconnect(VtSession * z,int error)59 vtDisconnect(VtSession *z, int error)
60 {
61 Packet *p;
62 uchar *b;
63
64 vtDebug(z, "vtDisconnect\n");
65 vtLock(z->lk);
66 if(z->cstate == VtStateConnected && !error && z->vtbl == nil) {
67 /* clean shutdown */
68 p = packetAlloc();
69 b = packetHeader(p, 2);
70 b[0] = VtQGoodbye;
71 b[1] = 0;
72 vtSendPacket(z, p);
73 }
74 if(z->fd >= 0)
75 vtFdClose(z->fd);
76 z->fd = -1;
77 z->cstate = VtStateClosed;
78 vtUnlock(z->lk);
79 }
80
81 void
vtClose(VtSession * z)82 vtClose(VtSession *z)
83 {
84 vtDisconnect(z, 0);
85 }
86
87 void
vtFree(VtSession * z)88 vtFree(VtSession *z)
89 {
90 if(z == nil)
91 return;
92 vtLockFree(z->lk);
93 vtSha1Free(z->inHash);
94 vtLockFree(z->inLock);
95 packetFree(z->part);
96 vtSha1Free(z->outHash);
97 vtLockFree(z->outLock);
98 vtMemFree(z->uid);
99 vtMemFree(z->sid);
100 vtMemFree(z->vtbl);
101
102 memset(z, 0, sizeof(VtSession));
103 z->fd = -1;
104
105 vtMemFree(z);
106 }
107
108 char *
vtGetUid(VtSession * s)109 vtGetUid(VtSession *s)
110 {
111 return s->uid;
112 }
113
114 char *
vtGetSid(VtSession * z)115 vtGetSid(VtSession *z)
116 {
117 return z->sid;
118 }
119
120 int
vtSetDebug(VtSession * z,int debug)121 vtSetDebug(VtSession *z, int debug)
122 {
123 int old;
124 vtLock(z->lk);
125 old = z->debug;
126 z->debug = debug;
127 vtUnlock(z->lk);
128 return old;
129 }
130
131 int
vtSetFd(VtSession * z,int fd)132 vtSetFd(VtSession *z, int fd)
133 {
134 vtLock(z->lk);
135 if(z->cstate != VtStateAlloc) {
136 vtSetError("bad state");
137 vtUnlock(z->lk);
138 return 0;
139 }
140 if(z->fd >= 0)
141 vtFdClose(z->fd);
142 z->fd = fd;
143 vtUnlock(z->lk);
144 return 1;
145 }
146
147 int
vtGetFd(VtSession * z)148 vtGetFd(VtSession *z)
149 {
150 return z->fd;
151 }
152
153 int
vtSetCryptoStrength(VtSession * z,int c)154 vtSetCryptoStrength(VtSession *z, int c)
155 {
156 if(z->cstate != VtStateAlloc) {
157 vtSetError("bad state");
158 return 0;
159 }
160 if(c != VtCryptoStrengthNone) {
161 vtSetError("not supported yet");
162 return 0;
163 }
164 return 1;
165 }
166
167 int
vtGetCryptoStrength(VtSession * s)168 vtGetCryptoStrength(VtSession *s)
169 {
170 return s->cryptoStrength;
171 }
172
173 int
vtSetCompression(VtSession * z,int fd)174 vtSetCompression(VtSession *z, int fd)
175 {
176 vtLock(z->lk);
177 if(z->cstate != VtStateAlloc) {
178 vtSetError("bad state");
179 vtUnlock(z->lk);
180 return 0;
181 }
182 z->fd = fd;
183 vtUnlock(z->lk);
184 return 1;
185 }
186
187 int
vtGetCompression(VtSession * s)188 vtGetCompression(VtSession *s)
189 {
190 return s->compression;
191 }
192
193 int
vtGetCrypto(VtSession * s)194 vtGetCrypto(VtSession *s)
195 {
196 return s->crypto;
197 }
198
199 int
vtGetCodec(VtSession * s)200 vtGetCodec(VtSession *s)
201 {
202 return s->codec;
203 }
204
205 char *
vtGetVersion(VtSession * z)206 vtGetVersion(VtSession *z)
207 {
208 int v, i;
209
210 v = z->version;
211 if(v == 0)
212 return "unknown";
213 for(i=0; vtVersions[i].version; i++)
214 if(vtVersions[i].version == v)
215 return vtVersions[i].s;
216 assert(0);
217 return 0;
218 }
219
220 /* hold z->inLock */
221 static int
vtVersionRead(VtSession * z,char * prefix,int * ret)222 vtVersionRead(VtSession *z, char *prefix, int *ret)
223 {
224 char c;
225 char buf[VtMaxStringSize];
226 char *q, *p, *pp;
227 int i;
228
229 q = prefix;
230 p = buf;
231 for(;;) {
232 if(p >= buf + sizeof(buf)) {
233 vtSetError(EBadVersion);
234 return 0;
235 }
236 if(!vtFdReadFully(z->fd, (uchar*)&c, 1))
237 return 0;
238 if(z->inHash)
239 vtSha1Update(z->inHash, (uchar*)&c, 1);
240 if(c == '\n') {
241 *p = 0;
242 break;
243 }
244 if(c < ' ' || *q && c != *q) {
245 vtSetError(EBadVersion);
246 return 0;
247 }
248 *p++ = c;
249 if(*q)
250 q++;
251 }
252
253 vtDebug(z, "version string in: %s\n", buf);
254
255 p = buf + strlen(prefix);
256 for(;;) {
257 for(pp=p; *pp && *pp != ':' && *pp != '-'; pp++)
258 ;
259 for(i=0; vtVersions[i].version; i++) {
260 if(strlen(vtVersions[i].s) != pp-p)
261 continue;
262 if(memcmp(vtVersions[i].s, p, pp-p) == 0) {
263 *ret = vtVersions[i].version;
264 return 1;
265 }
266 }
267 p = pp;
268 if(*p != ':')
269 return 0;
270 p++;
271 }
272 }
273
274 Packet*
vtRecvPacket(VtSession * z)275 vtRecvPacket(VtSession *z)
276 {
277 uchar buf[10], *b;
278 int n;
279 Packet *p;
280 int size, len;
281
282 if(z->cstate != VtStateConnected) {
283 vtSetError("session not connected");
284 return 0;
285 }
286
287 vtLock(z->inLock);
288 p = z->part;
289 /* get enough for head size */
290 size = packetSize(p);
291 while(size < 2) {
292 b = packetTrailer(p, MaxFragSize);
293 assert(b != nil);
294 n = vtFdRead(z->fd, b, MaxFragSize);
295 if(n <= 0)
296 goto Err;
297 size += n;
298 packetTrim(p, 0, size);
299 }
300
301 if(!packetConsume(p, buf, 2))
302 goto Err;
303 len = (buf[0] << 8) | buf[1];
304 size -= 2;
305
306 while(size < len) {
307 n = len - size;
308 if(n > MaxFragSize)
309 n = MaxFragSize;
310 b = packetTrailer(p, n);
311 if(!vtFdReadFully(z->fd, b, n))
312 goto Err;
313 size += n;
314 }
315 p = packetSplit(p, len);
316 vtUnlock(z->inLock);
317 return p;
318 Err:
319 vtUnlock(z->inLock);
320 return nil;
321 }
322
323 int
vtSendPacket(VtSession * z,Packet * p)324 vtSendPacket(VtSession *z, Packet *p)
325 {
326 IOchunk ioc;
327 int n;
328 uchar buf[2];
329
330 /* add framing */
331 n = packetSize(p);
332 if(n >= (1<<16)) {
333 vtSetError(EBigPacket);
334 packetFree(p);
335 return 0;
336 }
337 buf[0] = n>>8;
338 buf[1] = n;
339 packetPrefix(p, buf, 2);
340
341 for(;;) {
342 n = packetFragments(p, &ioc, 1, 0);
343 if(n == 0)
344 break;
345 if(!vtFdWrite(z->fd, ioc.addr, ioc.len)) {
346 packetFree(p);
347 return 0;
348 }
349 packetConsume(p, nil, n);
350 }
351 packetFree(p);
352 return 1;
353 }
354
355
356 int
vtGetString(Packet * p,char ** ret)357 vtGetString(Packet *p, char **ret)
358 {
359 uchar buf[2];
360 int n;
361 char *s;
362
363 if(!packetConsume(p, buf, 2))
364 return 0;
365 n = (buf[0]<<8) + buf[1];
366 if(n > VtMaxStringSize) {
367 vtSetError(EBigString);
368 return 0;
369 }
370 s = vtMemAlloc(n+1);
371 setmalloctag(s, getcallerpc(&p));
372 if(!packetConsume(p, (uchar*)s, n)) {
373 vtMemFree(s);
374 return 0;
375 }
376 s[n] = 0;
377 *ret = s;
378 return 1;
379 }
380
381 int
vtAddString(Packet * p,char * s)382 vtAddString(Packet *p, char *s)
383 {
384 uchar buf[2];
385 int n;
386
387 if(s == nil) {
388 vtSetError(ENullString);
389 return 0;
390 }
391 n = strlen(s);
392 if(n > VtMaxStringSize) {
393 vtSetError(EBigString);
394 return 0;
395 }
396 buf[0] = n>>8;
397 buf[1] = n;
398 packetAppend(p, buf, 2);
399 packetAppend(p, (uchar*)s, n);
400 return 1;
401 }
402
403 int
vtConnect(VtSession * z,char * password)404 vtConnect(VtSession *z, char *password)
405 {
406 char buf[VtMaxStringSize], *p, *ep, *prefix;
407 int i;
408
409 USED(password);
410 vtLock(z->lk);
411 if(z->cstate != VtStateAlloc) {
412 vtSetError("bad session state");
413 vtUnlock(z->lk);
414 return 0;
415 }
416 if(z->fd < 0){
417 vtSetError("%s", z->fderror);
418 vtUnlock(z->lk);
419 return 0;
420 }
421
422 /* be a little anal */
423 vtLock(z->inLock);
424 vtLock(z->outLock);
425
426 prefix = "venti-";
427 p = buf;
428 ep = buf + sizeof(buf);
429 p = seprint(p, ep, "%s", prefix);
430 p += strlen(p);
431 for(i=0; vtVersions[i].version; i++) {
432 if(i != 0)
433 *p++ = ':';
434 p = seprint(p, ep, "%s", vtVersions[i].s);
435 }
436 p = seprint(p, ep, "-libventi\n");
437 assert(p-buf < sizeof(buf));
438 if(z->outHash)
439 vtSha1Update(z->outHash, (uchar*)buf, p-buf);
440 if(!vtFdWrite(z->fd, (uchar*)buf, p-buf))
441 goto Err;
442
443 vtDebug(z, "version string out: %s", buf);
444
445 if(!vtVersionRead(z, prefix, &z->version))
446 goto Err;
447
448 vtDebug(z, "version = %d: %s\n", z->version, vtGetVersion(z));
449
450 vtUnlock(z->inLock);
451 vtUnlock(z->outLock);
452 z->cstate = VtStateConnected;
453 vtUnlock(z->lk);
454
455 if(z->vtbl)
456 return 1;
457
458 if(!vtHello(z))
459 goto Err;
460 return 1;
461 Err:
462 if(z->fd >= 0)
463 vtFdClose(z->fd);
464 z->fd = -1;
465 vtUnlock(z->inLock);
466 vtUnlock(z->outLock);
467 z->cstate = VtStateClosed;
468 vtUnlock(z->lk);
469 return 0;
470 }
471
472