xref: /plan9/sys/src/liboventi/rpc.c (revision 225077b0bf393489f69f6689df234a9b945497b7)
1 #include <u.h>
2 #include <libc.h>
3 #include <oventi.h>
4 #include "session.h"
5 
6 struct {
7 	int version;
8 	char *s;
9 } vtVersions[] = {
10 	VtVersion02, "02",
11 	0, 0,
12 };
13 
14 static char EBigString[] = "string too long";
15 static char EBigPacket[] = "packet too long";
16 static char ENullString[] = "missing string";
17 static char EBadVersion[] = "bad format in version string";
18 
19 static Packet *vtRPC(VtSession *z, int op, Packet *p);
20 
21 
22 VtSession *
vtAlloc(void)23 vtAlloc(void)
24 {
25 	VtSession *z;
26 
27 	z = vtMemAllocZ(sizeof(VtSession));
28 	z->lk = vtLockAlloc();
29 //	z->inHash = vtSha1Alloc();
30 	z->inLock = vtLockAlloc();
31 	z->part = packetAlloc();
32 //	z->outHash = vtSha1Alloc();
33 	z->outLock = vtLockAlloc();
34 	z->fd = -1;
35 	z->uid = vtStrDup("anonymous");
36 	z->sid = vtStrDup("anonymous");
37 	return z;
38 }
39 
40 void
vtReset(VtSession * z)41 vtReset(VtSession *z)
42 {
43 	vtLock(z->lk);
44 	z->cstate = VtStateAlloc;
45 	if(z->fd >= 0){
46 		vtFdClose(z->fd);
47 		z->fd = -1;
48 	}
49 	vtUnlock(z->lk);
50 }
51 
52 int
vtConnected(VtSession * z)53 vtConnected(VtSession *z)
54 {
55 	return z->cstate == VtStateConnected;
56 }
57 
58 void
vtDisconnect(VtSession * z,int error)59 vtDisconnect(VtSession *z, int error)
60 {
61 	Packet *p;
62 	uchar *b;
63 
64 vtDebug(z, "vtDisconnect\n");
65 	vtLock(z->lk);
66 	if(z->cstate == VtStateConnected && !error && z->vtbl == nil) {
67 		/* clean shutdown */
68 		p = packetAlloc();
69 		b = packetHeader(p, 2);
70 		b[0] = VtQGoodbye;
71 		b[1] = 0;
72 		vtSendPacket(z, p);
73 	}
74 	if(z->fd >= 0)
75 		vtFdClose(z->fd);
76 	z->fd = -1;
77 	z->cstate = VtStateClosed;
78 	vtUnlock(z->lk);
79 }
80 
81 void
vtClose(VtSession * z)82 vtClose(VtSession *z)
83 {
84 	vtDisconnect(z, 0);
85 }
86 
87 void
vtFree(VtSession * z)88 vtFree(VtSession *z)
89 {
90 	if(z == nil)
91 		return;
92 	vtLockFree(z->lk);
93 	vtSha1Free(z->inHash);
94 	vtLockFree(z->inLock);
95 	packetFree(z->part);
96 	vtSha1Free(z->outHash);
97 	vtLockFree(z->outLock);
98 	vtMemFree(z->uid);
99 	vtMemFree(z->sid);
100 	vtMemFree(z->vtbl);
101 
102 	memset(z, 0, sizeof(VtSession));
103 	z->fd = -1;
104 
105 	vtMemFree(z);
106 }
107 
108 char *
vtGetUid(VtSession * s)109 vtGetUid(VtSession *s)
110 {
111 	return s->uid;
112 }
113 
114 char *
vtGetSid(VtSession * z)115 vtGetSid(VtSession *z)
116 {
117 	return z->sid;
118 }
119 
120 int
vtSetDebug(VtSession * z,int debug)121 vtSetDebug(VtSession *z, int debug)
122 {
123 	int old;
124 	vtLock(z->lk);
125 	old = z->debug;
126 	z->debug = debug;
127 	vtUnlock(z->lk);
128 	return old;
129 }
130 
131 int
vtSetFd(VtSession * z,int fd)132 vtSetFd(VtSession *z, int fd)
133 {
134 	vtLock(z->lk);
135 	if(z->cstate != VtStateAlloc) {
136 		vtSetError("bad state");
137 		vtUnlock(z->lk);
138 		return 0;
139 	}
140 	if(z->fd >= 0)
141 		vtFdClose(z->fd);
142 	z->fd = fd;
143 	vtUnlock(z->lk);
144 	return 1;
145 }
146 
147 int
vtGetFd(VtSession * z)148 vtGetFd(VtSession *z)
149 {
150 	return z->fd;
151 }
152 
153 int
vtSetCryptoStrength(VtSession * z,int c)154 vtSetCryptoStrength(VtSession *z, int c)
155 {
156 	if(z->cstate != VtStateAlloc) {
157 		vtSetError("bad state");
158 		return 0;
159 	}
160 	if(c != VtCryptoStrengthNone) {
161 		vtSetError("not supported yet");
162 		return 0;
163 	}
164 	return 1;
165 }
166 
167 int
vtGetCryptoStrength(VtSession * s)168 vtGetCryptoStrength(VtSession *s)
169 {
170 	return s->cryptoStrength;
171 }
172 
173 int
vtSetCompression(VtSession * z,int fd)174 vtSetCompression(VtSession *z, int fd)
175 {
176 	vtLock(z->lk);
177 	if(z->cstate != VtStateAlloc) {
178 		vtSetError("bad state");
179 		vtUnlock(z->lk);
180 		return 0;
181 	}
182 	z->fd = fd;
183 	vtUnlock(z->lk);
184 	return 1;
185 }
186 
187 int
vtGetCompression(VtSession * s)188 vtGetCompression(VtSession *s)
189 {
190 	return s->compression;
191 }
192 
193 int
vtGetCrypto(VtSession * s)194 vtGetCrypto(VtSession *s)
195 {
196 	return s->crypto;
197 }
198 
199 int
vtGetCodec(VtSession * s)200 vtGetCodec(VtSession *s)
201 {
202 	return s->codec;
203 }
204 
205 char *
vtGetVersion(VtSession * z)206 vtGetVersion(VtSession *z)
207 {
208 	int v, i;
209 
210 	v = z->version;
211 	if(v == 0)
212 		return "unknown";
213 	for(i=0; vtVersions[i].version; i++)
214 		if(vtVersions[i].version == v)
215 			return vtVersions[i].s;
216 	assert(0);
217 	return 0;
218 }
219 
220 /* hold z->inLock */
221 static int
vtVersionRead(VtSession * z,char * prefix,int * ret)222 vtVersionRead(VtSession *z, char *prefix, int *ret)
223 {
224 	char c;
225 	char buf[VtMaxStringSize];
226 	char *q, *p, *pp;
227 	int i;
228 
229 	q = prefix;
230 	p = buf;
231 	for(;;) {
232 		if(p >= buf + sizeof(buf)) {
233 			vtSetError(EBadVersion);
234 			return 0;
235 		}
236 		if(!vtFdReadFully(z->fd, (uchar*)&c, 1))
237 			return 0;
238 		if(z->inHash)
239 			vtSha1Update(z->inHash, (uchar*)&c, 1);
240 		if(c == '\n') {
241 			*p = 0;
242 			break;
243 		}
244 		if(c < ' ' || *q && c != *q) {
245 			vtSetError(EBadVersion);
246 			return 0;
247 		}
248 		*p++ = c;
249 		if(*q)
250 			q++;
251 	}
252 
253 	vtDebug(z, "version string in: %s\n", buf);
254 
255 	p = buf + strlen(prefix);
256 	for(;;) {
257 		for(pp=p; *pp && *pp != ':'  && *pp != '-'; pp++)
258 			;
259 		for(i=0; vtVersions[i].version; i++) {
260 			if(strlen(vtVersions[i].s) != pp-p)
261 				continue;
262 			if(memcmp(vtVersions[i].s, p, pp-p) == 0) {
263 				*ret = vtVersions[i].version;
264 				return 1;
265 			}
266 		}
267 		p = pp;
268 		if(*p != ':')
269 			return 0;
270 		p++;
271 	}
272 }
273 
274 Packet*
vtRecvPacket(VtSession * z)275 vtRecvPacket(VtSession *z)
276 {
277 	uchar buf[10], *b;
278 	int n;
279 	Packet *p;
280 	int size, len;
281 
282 	if(z->cstate != VtStateConnected) {
283 		vtSetError("session not connected");
284 		return 0;
285 	}
286 
287 	vtLock(z->inLock);
288 	p = z->part;
289 	/* get enough for head size */
290 	size = packetSize(p);
291 	while(size < 2) {
292 		b = packetTrailer(p, MaxFragSize);
293 		assert(b != nil);
294 		n = vtFdRead(z->fd, b, MaxFragSize);
295 		if(n <= 0)
296 			goto Err;
297 		size += n;
298 		packetTrim(p, 0, size);
299 	}
300 
301 	if(!packetConsume(p, buf, 2))
302 		goto Err;
303 	len = (buf[0] << 8) | buf[1];
304 	size -= 2;
305 
306 	while(size < len) {
307 		n = len - size;
308 		if(n > MaxFragSize)
309 			n = MaxFragSize;
310 		b = packetTrailer(p, n);
311 		if(!vtFdReadFully(z->fd, b, n))
312 			goto Err;
313 		size += n;
314 	}
315 	p = packetSplit(p, len);
316 	vtUnlock(z->inLock);
317 	return p;
318 Err:
319 	vtUnlock(z->inLock);
320 	return nil;
321 }
322 
323 int
vtSendPacket(VtSession * z,Packet * p)324 vtSendPacket(VtSession *z, Packet *p)
325 {
326 	IOchunk ioc;
327 	int n;
328 	uchar buf[2];
329 
330 	/* add framing */
331 	n = packetSize(p);
332 	if(n >= (1<<16)) {
333 		vtSetError(EBigPacket);
334 		packetFree(p);
335 		return 0;
336 	}
337 	buf[0] = n>>8;
338 	buf[1] = n;
339 	packetPrefix(p, buf, 2);
340 
341 	for(;;) {
342 		n = packetFragments(p, &ioc, 1, 0);
343 		if(n == 0)
344 			break;
345 		if(!vtFdWrite(z->fd, ioc.addr, ioc.len)) {
346 			packetFree(p);
347 			return 0;
348 		}
349 		packetConsume(p, nil, n);
350 	}
351 	packetFree(p);
352 	return 1;
353 }
354 
355 
356 int
vtGetString(Packet * p,char ** ret)357 vtGetString(Packet *p, char **ret)
358 {
359 	uchar buf[2];
360 	int n;
361 	char *s;
362 
363 	if(!packetConsume(p, buf, 2))
364 		return 0;
365 	n = (buf[0]<<8) + buf[1];
366 	if(n > VtMaxStringSize) {
367 		vtSetError(EBigString);
368 		return 0;
369 	}
370 	s = vtMemAlloc(n+1);
371 	setmalloctag(s, getcallerpc(&p));
372 	if(!packetConsume(p, (uchar*)s, n)) {
373 		vtMemFree(s);
374 		return 0;
375 	}
376 	s[n] = 0;
377 	*ret = s;
378 	return 1;
379 }
380 
381 int
vtAddString(Packet * p,char * s)382 vtAddString(Packet *p, char *s)
383 {
384 	uchar buf[2];
385 	int n;
386 
387 	if(s == nil) {
388 		vtSetError(ENullString);
389 		return 0;
390 	}
391 	n = strlen(s);
392 	if(n > VtMaxStringSize) {
393 		vtSetError(EBigString);
394 		return 0;
395 	}
396 	buf[0] = n>>8;
397 	buf[1] = n;
398 	packetAppend(p, buf, 2);
399 	packetAppend(p, (uchar*)s, n);
400 	return 1;
401 }
402 
403 int
vtConnect(VtSession * z,char * password)404 vtConnect(VtSession *z, char *password)
405 {
406 	char buf[VtMaxStringSize], *p, *ep, *prefix;
407 	int i;
408 
409 	USED(password);
410 	vtLock(z->lk);
411 	if(z->cstate != VtStateAlloc) {
412 		vtSetError("bad session state");
413 		vtUnlock(z->lk);
414 		return 0;
415 	}
416 	if(z->fd < 0){
417 		vtSetError("%s", z->fderror);
418 		vtUnlock(z->lk);
419 		return 0;
420 	}
421 
422 	/* be a little anal */
423 	vtLock(z->inLock);
424 	vtLock(z->outLock);
425 
426 	prefix = "venti-";
427 	p = buf;
428 	ep = buf + sizeof(buf);
429 	p = seprint(p, ep, "%s", prefix);
430 	p += strlen(p);
431 	for(i=0; vtVersions[i].version; i++) {
432 		if(i != 0)
433 			*p++ = ':';
434 		p = seprint(p, ep, "%s", vtVersions[i].s);
435 	}
436 	p = seprint(p, ep, "-libventi\n");
437 	assert(p-buf < sizeof(buf));
438 	if(z->outHash)
439 		vtSha1Update(z->outHash, (uchar*)buf, p-buf);
440 	if(!vtFdWrite(z->fd, (uchar*)buf, p-buf))
441 		goto Err;
442 
443 	vtDebug(z, "version string out: %s", buf);
444 
445 	if(!vtVersionRead(z, prefix, &z->version))
446 		goto Err;
447 
448 	vtDebug(z, "version = %d: %s\n", z->version, vtGetVersion(z));
449 
450 	vtUnlock(z->inLock);
451 	vtUnlock(z->outLock);
452 	z->cstate = VtStateConnected;
453 	vtUnlock(z->lk);
454 
455 	if(z->vtbl)
456 		return 1;
457 
458 	if(!vtHello(z))
459 		goto Err;
460 	return 1;
461 Err:
462 	if(z->fd >= 0)
463 		vtFdClose(z->fd);
464 	z->fd = -1;
465 	vtUnlock(z->inLock);
466 	vtUnlock(z->outLock);
467 	z->cstate = VtStateClosed;
468 	vtUnlock(z->lk);
469 	return 0;
470 }
471 
472