xref: /plan9/sys/src/libventi/rpc.c (revision ec59a3ddbfceee0efe34584c2c9981a5e5ff1ec4)
1 #include <u.h>
2 #include <libc.h>
3 #include <venti.h>
4 #include "session.h"
5 
6 struct {
7 	int version;
8 	char *s;
9 } vtVersions[] = {
10 	VtVersion02, "02",
11 	0, 0,
12 };
13 
14 static char EBigString[] = "string too long";
15 static char EBigPacket[] = "packet too long";
16 static char ENullString[] = "missing string";
17 static char EBadVersion[] = "bad format in version string";
18 
19 static Packet *vtRPC(VtSession *z, int op, Packet *p);
20 
21 
22 VtSession *
23 vtAlloc(void)
24 {
25 	VtSession *z;
26 
27 	z = vtMemAllocZ(sizeof(VtSession));
28 	z->lk = vtLockAlloc();
29 //	z->inHash = vtSha1Alloc();
30 	z->inLock = vtLockAlloc();
31 	z->part = packetAlloc();
32 //	z->outHash = vtSha1Alloc();
33 	z->outLock = vtLockAlloc();
34 	z->fd = -1;
35 	z->uid = vtStrDup("anonymous");
36 	z->sid = vtStrDup("anonymous");
37 	return z;
38 }
39 
40 void
41 vtReset(VtSession *z)
42 {
43 	vtLock(z->lk);
44 	z->cstate = VtStateAlloc;
45 	if(z->fd >= 0){
46 		vtFdClose(z->fd);
47 		z->fd = -1;
48 	}
49 	vtUnlock(z->lk);
50 }
51 
52 int
53 vtConnected(VtSession *z)
54 {
55 	return z->cstate == VtStateConnected;
56 }
57 
58 void
59 vtDisconnect(VtSession *z, int error)
60 {
61 	Packet *p;
62 	uchar *b;
63 
64 vtDebug(z, "vtDisconnect\n");
65 	vtLock(z->lk);
66 	if(z->cstate == VtStateConnected && !error && z->vtbl == nil) {
67 		/* clean shutdown */
68 		p = packetAlloc();
69 		b = packetHeader(p, 2);
70 		b[0] = VtQGoodbye;
71 		b[1] = 0;
72 		vtSendPacket(z, p);
73 	}
74 	if(z->fd >= 0)
75 		vtFdClose(z->fd);
76 	z->fd = -1;
77 	z->cstate = VtStateClosed;
78 	vtUnlock(z->lk);
79 }
80 
81 void
82 vtClose(VtSession *z)
83 {
84 	vtDisconnect(z, 0);
85 }
86 
87 void
88 vtFree(VtSession *z)
89 {
90 	if(z == nil)
91 		return;
92 	vtLockFree(z->lk);
93 	vtSha1Free(z->inHash);
94 	vtLockFree(z->inLock);
95 	packetFree(z->part);
96 	vtSha1Free(z->outHash);
97 	vtLockFree(z->outLock);
98 	vtMemFree(z->uid);
99 	vtMemFree(z->sid);
100 	vtMemFree(z->vtbl);
101 
102 	memset(z, 0, sizeof(VtSession));
103 	z->fd = -1;
104 
105 	vtMemFree(z);
106 }
107 
108 char *
109 vtGetUid(VtSession *s)
110 {
111 	return s->uid;
112 }
113 
114 char *
115 vtGetSid(VtSession *z)
116 {
117 	return z->sid;
118 }
119 
120 int
121 vtSetDebug(VtSession *z, int debug)
122 {
123 	int old;
124 	vtLock(z->lk);
125 	old = z->debug;
126 	z->debug = debug;
127 	vtUnlock(z->lk);
128 	return old;
129 }
130 
131 int
132 vtSetFd(VtSession *z, int fd)
133 {
134 	vtLock(z->lk);
135 	if(z->cstate != VtStateAlloc) {
136 		vtSetError("bad state");
137 		vtUnlock(z->lk);
138 		return 0;
139 	}
140 	if(z->fd >= 0)
141 		vtFdClose(z->fd);
142 	z->fd = fd;
143 	vtUnlock(z->lk);
144 	return 1;
145 }
146 
147 int
148 vtGetFd(VtSession *z)
149 {
150 	return z->fd;
151 }
152 
153 int
154 vtSetCryptoStrength(VtSession *z, int c)
155 {
156 	if(z->cstate != VtStateAlloc) {
157 		vtSetError("bad state");
158 		return 0;
159 	}
160 	if(c != VtCryptoStrengthNone) {
161 		vtSetError("not supported yet");
162 		return 0;
163 	}
164 	return 1;
165 }
166 
167 int
168 vtGetCryptoStrength(VtSession *s)
169 {
170 	return s->cryptoStrength;
171 }
172 
173 int
174 vtSetCompression(VtSession *z, int fd)
175 {
176 	vtLock(z->lk);
177 	if(z->cstate != VtStateAlloc) {
178 		vtSetError("bad state");
179 		vtUnlock(z->lk);
180 		return 0;
181 	}
182 	z->fd = fd;
183 	vtUnlock(z->lk);
184 	return 1;
185 }
186 
187 int
188 vtGetCompression(VtSession *s)
189 {
190 	return s->compression;
191 }
192 
193 int
194 vtGetCrypto(VtSession *s)
195 {
196 	return s->crypto;
197 }
198 
199 int
200 vtGetCodec(VtSession *s)
201 {
202 	return s->codec;
203 }
204 
205 char *
206 vtGetVersion(VtSession *z)
207 {
208 	int v, i;
209 
210 	v = z->version;
211 	if(v == 0)
212 		return "unknown";
213 	for(i=0; vtVersions[i].version; i++)
214 		if(vtVersions[i].version == v)
215 			return vtVersions[i].s;
216 	assert(0);
217 	return 0;
218 }
219 
220 /* hold z->inLock */
221 static int
222 vtVersionRead(VtSession *z, char *prefix, int *ret)
223 {
224 	char c;
225 	char buf[VtMaxStringSize];
226 	char *q, *p, *pp;
227 	int i;
228 
229 	q = prefix;
230 	p = buf;
231 	for(;;) {
232 		if(p >= buf + sizeof(buf)) {
233 			vtSetError(EBadVersion);
234 			return 0;
235 		}
236 		if(!vtFdReadFully(z->fd, (uchar*)&c, 1))
237 			return 0;
238 		if(z->inHash)
239 			vtSha1Update(z->inHash, (uchar*)&c, 1);
240 		if(c == '\n') {
241 			*p = 0;
242 			break;
243 		}
244 		if(c < ' ' || c > 0x7f || *q && c != *q) {
245 			vtSetError(EBadVersion);
246 			return 0;
247 		}
248 		*p++ = c;
249 		if(*q)
250 			q++;
251 	}
252 
253 	vtDebug(z, "version string in: %s\n", buf);
254 
255 	p = buf + strlen(prefix);
256 	for(;;) {
257 		for(pp=p; *pp && *pp != ':'  && *pp != '-'; pp++)
258 			;
259 		for(i=0; vtVersions[i].version; i++) {
260 			if(strlen(vtVersions[i].s) != pp-p)
261 				continue;
262 			if(memcmp(vtVersions[i].s, p, pp-p) == 0) {
263 				*ret = vtVersions[i].version;
264 				return 1;
265 			}
266 		}
267 		p = pp;
268 		if(*p != ':')
269 			return 0;
270 		p++;
271 	}
272 }
273 
274 Packet*
275 vtRecvPacket(VtSession *z)
276 {
277 	uchar buf[10], *b;
278 	int n;
279 	Packet *p;
280 	int size, len;
281 
282 	if(z->cstate != VtStateConnected) {
283 		vtSetError("session not connected");
284 		return 0;
285 	}
286 
287 	vtLock(z->inLock);
288 	p = z->part;
289 	/* get enough for head size */
290 	size = packetSize(p);
291 	while(size < 2) {
292 		b = packetTrailer(p, MaxFragSize);
293 		assert(b != nil);
294 		n = vtFdRead(z->fd, b, MaxFragSize);
295 		if(n <= 0)
296 			goto Err;
297 		size += n;
298 		packetTrim(p, 0, size);
299 	}
300 
301 	if(!packetConsume(p, buf, 2))
302 		goto Err;
303 	len = (buf[0] << 8) | buf[1];
304 	size -= 2;
305 
306 	while(size < len) {
307 		n = len - size;
308 		if(n > MaxFragSize)
309 			n = MaxFragSize;
310 		b = packetTrailer(p, n);
311 		if(!vtFdReadFully(z->fd, b, n))
312 			goto Err;
313 		size += n;
314 	}
315 	p = packetSplit(p, len);
316 	vtUnlock(z->inLock);
317 	return p;
318 Err:
319 	vtUnlock(z->inLock);
320 	return nil;
321 }
322 
323 int
324 vtSendPacket(VtSession *z, Packet *p)
325 {
326 	IOchunk ioc;
327 	int n;
328 	uchar buf[2];
329 
330 	/* add framing */
331 	n = packetSize(p);
332 	if(n >= (1<<16)) {
333 		vtSetError(EBigPacket);
334 		packetFree(p);
335 		return 0;
336 	}
337 	buf[0] = n>>8;
338 	buf[1] = n;
339 	packetPrefix(p, buf, 2);
340 
341 	for(;;) {
342 		n = packetFragments(p, &ioc, 1, 0);
343 		if(n == 0)
344 			break;
345 		if(!vtFdWrite(z->fd, ioc.addr, ioc.len)) {
346 			packetFree(p);
347 			return 0;
348 		}
349 		packetConsume(p, nil, n);
350 	}
351 	packetFree(p);
352 	return 1;
353 }
354 
355 
356 int
357 vtGetString(Packet *p, char **ret)
358 {
359 	uchar buf[2];
360 	int n;
361 	char *s;
362 
363 	if(!packetConsume(p, buf, 2))
364 		return 0;
365 	n = (buf[0]<<8) + buf[1];
366 	if(n > VtMaxStringSize) {
367 		vtSetError(EBigString);
368 		return 0;
369 	}
370 	s = vtMemAlloc(n+1);
371 	if(!packetConsume(p, (uchar*)s, n)) {
372 		vtMemFree(s);
373 		return 0;
374 	}
375 	s[n] = 0;
376 	*ret = s;
377 	return 1;
378 }
379 
380 int
381 vtAddString(Packet *p, char *s)
382 {
383 	uchar buf[2];
384 	int n;
385 
386 	if(s == nil) {
387 		vtSetError(ENullString);
388 		return 0;
389 	}
390 	n = strlen(s);
391 	if(n > VtMaxStringSize) {
392 		vtSetError(EBigString);
393 		return 0;
394 	}
395 	buf[0] = n>>8;
396 	buf[1] = n;
397 	packetAppend(p, buf, 2);
398 	packetAppend(p, (uchar*)s, n);
399 	return 1;
400 }
401 
402 int
403 vtConnect(VtSession *z, char *password)
404 {
405 	char buf[VtMaxStringSize], *p, *ep, *prefix;
406 	int i;
407 
408 	USED(password);
409 	vtLock(z->lk);
410 	if(z->cstate != VtStateAlloc) {
411 		vtSetError("bad session state");
412 		vtUnlock(z->lk);
413 		return 0;
414 	}
415 	if(z->fd < 0){
416 		vtSetError("%s", z->fderror);
417 		vtUnlock(z->lk);
418 		return 0;
419 	}
420 
421 	/* be a little anal */
422 	vtLock(z->inLock);
423 	vtLock(z->outLock);
424 
425 	prefix = "venti-";
426 	p = buf;
427 	ep = buf + sizeof(buf);
428 	p = seprint(p, ep, "%s", prefix);
429 	p += strlen(p);
430 	for(i=0; vtVersions[i].version; i++) {
431 		if(i != 0)
432 			*p++ = ':';
433 		p = seprint(p, ep, "%s", vtVersions[i].s);
434 	}
435 	p = seprint(p, ep, "-libventi\n");
436 	assert(p-buf < sizeof(buf));
437 	if(z->outHash)
438 		vtSha1Update(z->outHash, (uchar*)buf, p-buf);
439 	if(!vtFdWrite(z->fd, (uchar*)buf, p-buf))
440 		goto Err;
441 
442 	vtDebug(z, "version string out: %s", buf);
443 
444 	if(!vtVersionRead(z, prefix, &z->version))
445 		goto Err;
446 
447 	vtDebug(z, "version = %d: %s\n", z->version, vtGetVersion(z));
448 
449 	vtUnlock(z->inLock);
450 	vtUnlock(z->outLock);
451 	z->cstate = VtStateConnected;
452 	vtUnlock(z->lk);
453 
454 	if(z->vtbl)
455 		return 1;
456 
457 	if(!vtHello(z))
458 		goto Err;
459 	return 1;
460 Err:
461 	if(z->fd >= 0)
462 		vtFdClose(z->fd);
463 	z->fd = -1;
464 	vtUnlock(z->inLock);
465 	vtUnlock(z->outLock);
466 	z->cstate = VtStateClosed;
467 	vtUnlock(z->lk);
468 	return 0;
469 }
470 
471