xref: /plan9/sys/src/cmd/aan.c (revision 58da3067adcdccaaa043d0bfde28ba83b7ced07d)
1 #include <u.h>
2 #include <libc.h>
3 #include <auth.h>
4 #include <fcall.h>
5 #include <thread.h>
6 
/*
 * Time-unit conversions.  Each yields a vlong count of nanoseconds,
 * the unit nsec() returns.  The argument is parenthesized so that
 * an expression such as NS(a+b) converts the whole sum, not just a.
 */
#define NS(x)		((vlong)(x))
#define US(x)		(NS(x) * 1000LL)
#define MS(x)		(US(x) * 1000LL)
#define S(x)		(MS(x) * 1000LL)

#define LOGNAME	"aan"		/* log name handed to syslog() */
13 
14 enum {
15 	Synctime = S(8),
16 	Nbuf = 10,
17 	K = 1024,
18 	Bufsize = 8 * K,
19 	Stacksize = 8 * K,
20 	Timer = 0,		/* Alt channels. */
21 	Unsent = 1,
22 	Maxto = 24 * 3600,	/* A full day to reconnect. */
23 	Hdrsz = 12,
24 };
25 
/*
 * The two ends of a network connection, each split into system and
 * service.  Filled in by getendpoints() and released by freeendpoints().
 */
typedef struct Endpoints Endpoints;
struct Endpoints {
	char	*lsys, *lserv;		/* local system and service */
	char	*rsys, *rserv;		/* remote system and service */
};
33 
34 typedef struct {
35 	ulong	nb;		/* Number of data bytes in this message */
36 	ulong	msg;		/* Message number */
37 	ulong	acked;		/* Number of messages acked */
38 } Hdr;
39 
40 typedef struct {
41 	Hdr	hdr;
42 	uchar	buf[Bufsize];
43 } Buf;
44 
45 static char	*Logname = LOGNAME;
46 static int	client;
47 static int	debug;
48 static char	*devdir;
49 static char	*dialstring;
50 static int	done;
51 static int	inmsg;
52 static int	maxto = Maxto;
53 static int	netfd;
54 
55 static Channel	*empty;
56 static Channel	*unacked;
57 static Channel	*unsent;
58 
59 static Alt a[] = {
60 	/*	c	v	 op   */
61 	{ 	nil,	nil,	CHANRCV	},	/* timer */
62 	{	nil,	nil,	CHANRCV	},	/* unsent */
63 	{ 	nil,	nil,	CHANEND	},
64 };
65 
66 static void	dmessage(int, char *, ...);
67 static void	freeendpoints(Endpoints *);
68 static void	fromclient(void*);
69 static void	fromnet(void*);
70 static Endpoints *getendpoints(char *);
71 static void	packhdr(Hdr *, uchar *);
72 static void	reconnect(void);
73 static void	showmsg(int, char *, Buf *);
74 static void	synchronize(void);
75 static void	timerproc(void *);
76 static void	unpackhdr(Hdr *, uchar *);
77 static int	writen(int, uchar *, int);
78 
79 static void
usage(void)80 usage(void)
81 {
82 	fprint(2, "Usage: %s [-cd] [-m maxto] dialstring|netdir\n", argv0);
83 	threadexitsall("usage");
84 }
85 
86 static int
catch(void *,char * s)87 catch(void *, char *s)
88 {
89 	if (strstr(s, "alarm") != nil) {
90 		syslog(0, Logname, "Timed out waiting for client on %s, exiting...",
91 			   devdir);
92 		threadexitsall(nil);
93 	}
94 	return 0;
95 }
96 
97 void
threadmain(int argc,char ** argv)98 threadmain(int argc, char **argv)
99 {
100 	int i, fd, failed, delta;
101 	vlong synctime, now;
102 	char *p;
103 	uchar buf[Hdrsz];
104 	Buf *b, *eb;
105 	Channel *timer;
106 	Hdr hdr;
107 
108 	ARGBEGIN {
109 	case 'c':
110 		client++;
111 		break;
112 	case 'd':
113 		debug++;
114 		break;
115 	case 'm':
116 		maxto = strtol(EARGF(usage()), (char **)nil, 0);
117 		break;
118 	default:
119 		usage();
120 	} ARGEND;
121 
122 	if (argc != 1)
123 		usage();
124 
125 	if (!client) {
126 		devdir = argv[0];
127 		if ((p = strstr(devdir, "/local")) != nil)
128 			*p = '\0';
129 	}else
130 		dialstring = argv[0];
131 
132 	if (debug > 0) {
133 		fd = open("#c/cons", OWRITE|OCEXEC);
134 		dup(fd, 2);
135 	}
136 
137 	fmtinstall('F', fcallfmt);
138 
139 	atnotify(catch, 1);
140 
141 	unsent = chancreate(sizeof(Buf *), Nbuf);
142 	unacked= chancreate(sizeof(Buf *), Nbuf);
143 	empty  = chancreate(sizeof(Buf *), Nbuf);
144 	timer  = chancreate(sizeof(uchar *), 1);
145 
146 	for (i = 0; i != Nbuf; i++) {
147 		eb = malloc(sizeof(Buf));
148 		sendp(empty, eb);
149 	}
150 
151 	netfd = -1;
152 
153 	if (proccreate(fromnet, nil, Stacksize) < 0)
154 		sysfatal("Cannot start fromnet; %r");
155 
156 	reconnect();		/* Set up the initial connection. */
157 	synchronize();
158 
159 	if (proccreate(fromclient, nil, Stacksize) < 0)
160 		sysfatal("cannot start fromclient; %r");
161 
162 	if (proccreate(timerproc, timer, Stacksize) < 0)
163 		sysfatal("Cannot start timerproc; %r");
164 
165 	a[Timer].c = timer;
166 	a[Unsent].c = unsent;
167 	a[Unsent].v = &b;
168 
169 	synctime = nsec() + Synctime;
170 	failed = 0;
171 	while (!done) {
172 		if (failed) {
173 			/* Wait for the netreader to die. */
174 			while (netfd >= 0) {
175 				dmessage(1, "main; waiting for netreader to die\n");
176 				sleep(1000);
177 			}
178 
179 			/* the reader died; reestablish the world. */
180 			reconnect();
181 			synchronize();
182 			failed = 0;
183 		}
184 
185 		now = nsec();
186 		delta = (synctime - nsec()) / MS(1);
187 
188 		if (delta <= 0) {
189 			hdr.nb = 0;
190 			hdr.acked = inmsg;
191 			hdr.msg = -1;
192 			packhdr(&hdr, buf);
193 			if (writen(netfd, buf, sizeof(buf)) < 0) {
194 				dmessage(2, "main; writen failed; %r\n");
195 				failed = 1;
196 				continue;
197 			}
198 			synctime = nsec() + Synctime;
199 			assert(synctime > now);
200 		}
201 
202 		switch (alt(a)) {
203 		case Timer:
204 			break;
205 		case Unsent:
206 			sendp(unacked, b);
207 
208 			b->hdr.acked = inmsg;
209 			packhdr(&b->hdr, buf);
210 			if (writen(netfd, buf, sizeof(buf)) < 0 ||
211 			    writen(netfd, b->buf, b->hdr.nb) < 0) {
212 				dmessage(2, "main; writen failed; %r\n");
213 				failed = 1;
214 			}
215 
216 			if (b->hdr.nb == 0)
217 				done = 1;
218 			break;
219 		}
220 	}
221 	syslog(0, Logname, "exiting...");
222 	threadexitsall(nil);
223 }
224 
225 static void
fromclient(void *)226 fromclient(void*)
227 {
228 	Buf *b;
229 	static int outmsg;
230 
231 	do {
232 		b = recvp(empty);
233 		if ((int)(b->hdr.nb = read(0, b->buf, Bufsize)) <= 0) {
234 			if ((int)b->hdr.nb < 0)
235 				dmessage(2, "fromclient; Cannot read 9P message; %r\n");
236 			else
237 				dmessage(2, "fromclient; Client terminated\n");
238 			b->hdr.nb = 0;
239 		}
240 		b->hdr.msg = outmsg++;
241 
242 		showmsg(1, "fromclient", b);
243 		sendp(unsent, b);
244 	} while (b->hdr.nb != 0);
245 }
246 
247 static void
fromnet(void *)248 fromnet(void*)
249 {
250 	int len, acked, i;
251 	uchar buf[Hdrsz];
252 	Buf *b, *rb;
253 	static int lastacked;
254 
255 	b = (Buf *)malloc(sizeof(Buf));
256 	assert(b);
257 
258 	while (!done) {
259 		while (netfd < 0) {
260 			dmessage(1, "fromnet; waiting for connection... (inmsg %d)\n",
261 					  inmsg);
262 			sleep(1000);
263 		}
264 
265 		/* Read the header. */
266 		if ((len = readn(netfd, buf, sizeof(buf))) <= 0) {
267 			if (len < 0)
268 				dmessage(1, "fromnet; (hdr) network failure; %r\n");
269 			else
270 				dmessage(1, "fromnet; (hdr) network closed\n");
271 			close(netfd);
272 			netfd = -1;
273 			continue;
274 		}
275 		unpackhdr(&b->hdr, buf);
276 		dmessage(2, "fromnet: Got message, size %d, nb %d, msg %d\n",
277 			len, b->hdr.nb, b->hdr.msg);
278 
279 		if (b->hdr.nb == 0) {
280 			if  ((long)b->hdr.msg >= 0) {
281 				dmessage(1, "fromnet; network closed\n");
282 				break;
283 			}
284 			continue;
285 		}
286 
287 		if ((len = readn(netfd, b->buf, b->hdr.nb)) <= 0 ||
288 		    len != b->hdr.nb) {
289 			if (len == 0)
290 				dmessage(1, "fromnet; network closed\n");
291 			else
292 				dmessage(1, "fromnet; network failure; %r\n");
293 			close(netfd);
294 			netfd = -1;
295 			continue;
296 		}
297 
298 		if (b->hdr.msg < inmsg) {
299 			dmessage(1, "fromnet; skipping message %d, currently at %d\n",
300 				b->hdr.msg, inmsg);
301 			continue;
302 		}
303 
304 		/* Process the acked list. */
305 		acked = b->hdr.acked - lastacked;
306 		for (i = 0; i != acked; i++) {
307 			rb = recvp(unacked);
308 			if (rb->hdr.msg != lastacked + i) {
309 				dmessage(1, "rb %p, msg %d, lastacked %d, i %d\n",
310 					rb, rb? rb->hdr.msg: -2, lastacked, i);
311 				assert(0);
312 			}
313 			rb->hdr.msg = -1;
314 			sendp(empty, rb);
315 		}
316 		lastacked = b->hdr.acked;
317 		inmsg++;
318 		showmsg(1, "fromnet", b);
319 		if (writen(1, b->buf, len) < 0)
320 			sysfatal("fromnet; cannot write to client; %r");
321 	}
322 	done = 1;
323 }
324 
325 static void
reconnect(void)326 reconnect(void)
327 {
328 	char err[32], ldir[40];
329 	int lcfd, fd;
330 	Endpoints *ep;
331 
332 	if (dialstring) {
333 		syslog(0, Logname, "dialing %s", dialstring);
334   		while ((fd = dial(dialstring, nil, nil, nil)) < 0) {
335 			err[0] = '\0';
336 			errstr(err, sizeof err);
337 			if (strstr(err, "connection refused")) {
338 				dmessage(1, "reconnect; server died...\n");
339 				threadexitsall("server died...");
340 			}
341 			dmessage(1, "reconnect: dialed %s; %s\n", dialstring, err);
342 			sleep(1000);
343 		}
344 		syslog(0, Logname, "reconnected to %s", dialstring);
345 	} else {
346 		syslog(0, Logname, "waiting for connection on %s", devdir);
347 		alarm(maxto * 1000);
348  		if ((lcfd = listen(devdir, ldir)) < 0)
349 			sysfatal("reconnect; cannot listen; %r");
350 
351 		if ((fd = accept(lcfd, ldir)) < 0)
352 			sysfatal("reconnect; cannot accept; %r");
353 		alarm(0);
354 		close(lcfd);
355 
356 		ep = getendpoints(ldir);
357 		dmessage(1, "rsys '%s'\n", ep->rsys);
358 		syslog(0, Logname, "connected from %s", ep->rsys);
359 		freeendpoints(ep);
360 	}
361 	netfd = fd;			/* Wakes up the netreader. */
362 }
363 
364 static void
synchronize(void)365 synchronize(void)
366 {
367 	Channel *tmp;
368 	Buf *b;
369 	uchar buf[Hdrsz];
370 
371 	/*
372 	 * Ignore network errors here.  If we fail during
373 	 * synchronization, the next alarm will pick up
374 	 * the error.
375 	 */
376 	tmp = chancreate(sizeof(Buf *), Nbuf);
377 	while ((b = nbrecvp(unacked)) != nil) {
378 		packhdr(&b->hdr, buf);
379 		writen(netfd, buf, sizeof(buf));
380 		writen(netfd, b->buf, b->hdr.nb);
381 		sendp(tmp, b);
382 	}
383 	chanfree(unacked);
384 	unacked = tmp;
385 }
386 
387 static void
showmsg(int level,char * s,Buf * b)388 showmsg(int level, char *s, Buf *b)
389 {
390 	if (b == nil) {
391 		dmessage(level, "%s; b == nil\n", s);
392 		return;
393 	}
394 	dmessage(level, "%s;  (len %d) %X %X %X %X %X %X %X %X %X (%p)\n", s,
395 		b->hdr.nb,
396 		b->buf[0], b->buf[1], b->buf[2],
397 		b->buf[3], b->buf[4], b->buf[5],
398 		b->buf[6], b->buf[7], b->buf[8], b);
399 }
400 
401 static int
writen(int fd,uchar * buf,int nb)402 writen(int fd, uchar *buf, int nb)
403 {
404 	int n, len = nb;
405 
406 	while (nb > 0) {
407 		if (fd < 0)
408 			return -1;
409 		if ((n = write(fd, buf, nb)) < 0) {
410 			dmessage(1, "writen; Write failed; %r\n");
411 			return -1;
412 		}
413 		dmessage(2, "writen: wrote %d bytes\n", n);
414 		buf += n;
415 		nb -= n;
416 	}
417 	return len;
418 }
419 
420 static void
timerproc(void * x)421 timerproc(void *x)
422 {
423 	Channel *timer = x;
424 
425 	while (!done) {
426 		sleep((Synctime / MS(1)) >> 1);
427 		sendp(timer, "timer");
428 	}
429 }
430 
431 static void
dmessage(int level,char * fmt,...)432 dmessage(int level, char *fmt, ...)
433 {
434 	va_list arg;
435 
436 	if (level > debug)
437 		return;
438 	va_start(arg, fmt);
439 	vfprint(2, fmt, arg);
440 	va_end(arg);
441 }
442 
443 static void
getendpoint(char * dir,char * file,char ** sysp,char ** servp)444 getendpoint(char *dir, char *file, char **sysp, char **servp)
445 {
446 	int fd, n;
447 	char buf[128];
448 	char *sys, *serv;
449 
450 	sys = serv = 0;
451 	snprint(buf, sizeof buf, "%s/%s", dir, file);
452 	fd = open(buf, OREAD);
453 	if(fd >= 0){
454 		n = read(fd, buf, sizeof(buf)-1);
455 		if(n>0){
456 			buf[n-1] = 0;
457 			serv = strchr(buf, '!');
458 			if(serv){
459 				*serv++ = 0;
460 				serv = strdup(serv);
461 			}
462 			sys = strdup(buf);
463 		}
464 		close(fd);
465 	}
466 	if(serv == 0)
467 		serv = strdup("unknown");
468 	if(sys == 0)
469 		sys = strdup("unknown");
470 	*servp = serv;
471 	*sysp = sys;
472 }
473 
474 static Endpoints *
getendpoints(char * dir)475 getendpoints(char *dir)
476 {
477 	Endpoints *ep;
478 
479 	ep = malloc(sizeof(*ep));
480 	getendpoint(dir, "local", &ep->lsys, &ep->lserv);
481 	getendpoint(dir, "remote", &ep->rsys, &ep->rserv);
482 	return ep;
483 }
484 
485 static void
freeendpoints(Endpoints * ep)486 freeendpoints(Endpoints *ep)
487 {
488 	free(ep->lsys);
489 	free(ep->rsys);
490 	free(ep->lserv);
491 	free(ep->rserv);
492 	free(ep);
493 }
494 
/*
 * Little-endian 32-bit get/put; p must be a uchar*.
 * Each byte is widened to ulong before shifting, so p[3]<<24 cannot
 * shift into the sign bit of a promoted int.  The argument p is
 * parenthesized so expressions like buf+4 work.  U32PUT is a single
 * statement (do/while), so it is safe as an unbraced if or loop body.
 */
#define	U32GET(p)	((ulong)(p)[0] | (ulong)(p)[1]<<8 | \
			(ulong)(p)[2]<<16 | (ulong)(p)[3]<<24)
#define	U32PUT(p,v)	do { (p)[0] = (v); (p)[1] = (v)>>8; \
			(p)[2] = (v)>>16; (p)[3] = (v)>>24; } while(0)
499 
500 static void
packhdr(Hdr * hdr,uchar * buf)501 packhdr(Hdr *hdr, uchar *buf)
502 {
503 	uchar *p;
504 
505 	p = buf;
506 	U32PUT(p, hdr->nb);
507 	p += 4;
508 	U32PUT(p, hdr->msg);
509 	p += 4;
510 	U32PUT(p, hdr->acked);
511 }
512 
513 static void
unpackhdr(Hdr * hdr,uchar * buf)514 unpackhdr(Hdr *hdr, uchar *buf)
515 {
516 	uchar *p;
517 
518 	p = buf;
519 	hdr->nb = U32GET(p);
520 	p += 4;
521 	hdr->msg = U32GET(p);
522 	p += 4;
523 	hdr->acked = U32GET(p);
524 }
525