1 #include <u.h>
2 #include <libc.h>
3 #include <auth.h>
4 #include <fcall.h>
5 #include <thread.h>
6
/*
 * Time unit conversions yielding nanoseconds as a vlong, the unit that
 * synctime is compared in against nsec().  The argument is fully
 * parenthesized so that expressions such as NS(a*b) convert the whole
 * value rather than just its first operand.
 */
#define NS(x) ((vlong)(x))
#define US(x) (NS(x) * 1000LL)
#define MS(x) (US(x) * 1000LL)
#define S(x) (MS(x) * 1000LL)
11
/* Name this program logs under via syslog(). */
#define LOGNAME "aan"

enum {
	Synctime = S(8),	/* Idle interval (ns) after which main sends a header-only ack/keepalive. */
	Nbuf = 10,		/* Number of Bufs in circulation; also the capacity of each buffer channel. */
	K = 1024,
	Bufsize = 8 * K,	/* Payload capacity of one Buf; also the per-read size from the client. */
	Stacksize = 8 * K,	/* Stack size passed to proccreate. */
	Timer = 0,		/* Alt channels: indexes into a[]. */
	Unsent = 1,
	Maxto = 24 * 3600,	/* A full day to reconnect (seconds; scaled to ms for alarm()). */
	Hdrsz = 12,		/* Wire header size: three 4-byte fields (nb, msg, acked). */
};
25
/*
 * System and service names of a connection, read from the "local" and
 * "remote" files of a network connection directory (each "sys!serv").
 */
typedef struct Endpoints Endpoints;
struct Endpoints {
	char *lsys;	/* local system */
	char *lserv;	/* local service */
	char *rsys;	/* remote system */
	char *rserv;	/* remote service */
};

/*
 * In-memory form of the wire header (see packhdr/unpackhdr).
 * A header with nb == 0 carries no payload: msg == (ulong)-1 marks a
 * keepalive/ack, any other msg signals the end of the stream.
 */
typedef struct {
	ulong nb; /* Number of data bytes in this message */
	ulong msg; /* Message number */
	ulong acked; /* Number of messages acked */
} Hdr;

/* One message: header plus up to Bufsize bytes of payload. */
typedef struct {
	Hdr hdr;
	uchar buf[Bufsize];
} Buf;
44
static char *Logname = LOGNAME;
static int client;		/* -c: dial dialstring instead of listening on devdir */
static int debug;		/* -d (repeatable): dmessage verbosity level */
static char *devdir;		/* server mode: network directory to listen on */
static char *dialstring;	/* client mode: address to dial */
static int done;		/* set once either direction has finished */
static int inmsg;		/* count of data messages accepted from the network; sent back as our ack */
static int maxto = Maxto;	/* -m: server-mode seconds to wait for a (re)connection */
static int netfd;		/* current network connection, or -1 while disconnected */

/*
 * Buffer channels: empty holds free Bufs, unacked holds Bufs written to
 * the network but not yet acknowledged (resent by synchronize), unsent
 * holds Bufs read from the client awaiting transmission.
 */
static Channel *empty;
static Channel *unacked;
static Channel *unsent;

/* Alt set for threadmain's loop; c and v are filled in at run time. */
static Alt a[] = {
	/*	c	v	op	*/
	{ nil,	nil,	CHANRCV },	/* timer */
	{ nil,	nil,	CHANRCV },	/* unsent */
	{ nil,	nil,	CHANEND },
};
65
/* Forward declarations of the file-local helpers defined below. */
static void dmessage(int, char *, ...);
static void freeendpoints(Endpoints *);
static void fromclient(void*);
static void fromnet(void*);
static Endpoints *getendpoints(char *);
static void packhdr(Hdr *, uchar *);
static void reconnect(void);
static void showmsg(int, char *, Buf *);
static void synchronize(void);
static void timerproc(void *);
static void unpackhdr(Hdr *, uchar *);
static int writen(int, uchar *, int);
78
79 static void
usage(void)80 usage(void)
81 {
82 fprint(2, "Usage: %s [-cd] [-m maxto] dialstring|netdir\n", argv0);
83 threadexitsall("usage");
84 }
85
86 static int
catch(void *,char * s)87 catch(void *, char *s)
88 {
89 if (strstr(s, "alarm") != nil) {
90 syslog(0, Logname, "Timed out waiting for client on %s, exiting...",
91 devdir);
92 threadexitsall(nil);
93 }
94 return 0;
95 }
96
97 void
threadmain(int argc,char ** argv)98 threadmain(int argc, char **argv)
99 {
100 int i, fd, failed, delta;
101 vlong synctime, now;
102 char *p;
103 uchar buf[Hdrsz];
104 Buf *b, *eb;
105 Channel *timer;
106 Hdr hdr;
107
108 ARGBEGIN {
109 case 'c':
110 client++;
111 break;
112 case 'd':
113 debug++;
114 break;
115 case 'm':
116 maxto = strtol(EARGF(usage()), (char **)nil, 0);
117 break;
118 default:
119 usage();
120 } ARGEND;
121
122 if (argc != 1)
123 usage();
124
125 if (!client) {
126 devdir = argv[0];
127 if ((p = strstr(devdir, "/local")) != nil)
128 *p = '\0';
129 }else
130 dialstring = argv[0];
131
132 if (debug > 0) {
133 fd = open("#c/cons", OWRITE|OCEXEC);
134 dup(fd, 2);
135 }
136
137 fmtinstall('F', fcallfmt);
138
139 atnotify(catch, 1);
140
141 unsent = chancreate(sizeof(Buf *), Nbuf);
142 unacked= chancreate(sizeof(Buf *), Nbuf);
143 empty = chancreate(sizeof(Buf *), Nbuf);
144 timer = chancreate(sizeof(uchar *), 1);
145
146 for (i = 0; i != Nbuf; i++) {
147 eb = malloc(sizeof(Buf));
148 sendp(empty, eb);
149 }
150
151 netfd = -1;
152
153 if (proccreate(fromnet, nil, Stacksize) < 0)
154 sysfatal("Cannot start fromnet; %r");
155
156 reconnect(); /* Set up the initial connection. */
157 synchronize();
158
159 if (proccreate(fromclient, nil, Stacksize) < 0)
160 sysfatal("cannot start fromclient; %r");
161
162 if (proccreate(timerproc, timer, Stacksize) < 0)
163 sysfatal("Cannot start timerproc; %r");
164
165 a[Timer].c = timer;
166 a[Unsent].c = unsent;
167 a[Unsent].v = &b;
168
169 synctime = nsec() + Synctime;
170 failed = 0;
171 while (!done) {
172 if (failed) {
173 /* Wait for the netreader to die. */
174 while (netfd >= 0) {
175 dmessage(1, "main; waiting for netreader to die\n");
176 sleep(1000);
177 }
178
179 /* the reader died; reestablish the world. */
180 reconnect();
181 synchronize();
182 failed = 0;
183 }
184
185 now = nsec();
186 delta = (synctime - nsec()) / MS(1);
187
188 if (delta <= 0) {
189 hdr.nb = 0;
190 hdr.acked = inmsg;
191 hdr.msg = -1;
192 packhdr(&hdr, buf);
193 if (writen(netfd, buf, sizeof(buf)) < 0) {
194 dmessage(2, "main; writen failed; %r\n");
195 failed = 1;
196 continue;
197 }
198 synctime = nsec() + Synctime;
199 assert(synctime > now);
200 }
201
202 switch (alt(a)) {
203 case Timer:
204 break;
205 case Unsent:
206 sendp(unacked, b);
207
208 b->hdr.acked = inmsg;
209 packhdr(&b->hdr, buf);
210 if (writen(netfd, buf, sizeof(buf)) < 0 ||
211 writen(netfd, b->buf, b->hdr.nb) < 0) {
212 dmessage(2, "main; writen failed; %r\n");
213 failed = 1;
214 }
215
216 if (b->hdr.nb == 0)
217 done = 1;
218 break;
219 }
220 }
221 syslog(0, Logname, "exiting...");
222 threadexitsall(nil);
223 }
224
225 static void
fromclient(void *)226 fromclient(void*)
227 {
228 Buf *b;
229 static int outmsg;
230
231 do {
232 b = recvp(empty);
233 if ((int)(b->hdr.nb = read(0, b->buf, Bufsize)) <= 0) {
234 if ((int)b->hdr.nb < 0)
235 dmessage(2, "fromclient; Cannot read 9P message; %r\n");
236 else
237 dmessage(2, "fromclient; Client terminated\n");
238 b->hdr.nb = 0;
239 }
240 b->hdr.msg = outmsg++;
241
242 showmsg(1, "fromclient", b);
243 sendp(unsent, b);
244 } while (b->hdr.nb != 0);
245 }
246
247 static void
fromnet(void *)248 fromnet(void*)
249 {
250 int len, acked, i;
251 uchar buf[Hdrsz];
252 Buf *b, *rb;
253 static int lastacked;
254
255 b = (Buf *)malloc(sizeof(Buf));
256 assert(b);
257
258 while (!done) {
259 while (netfd < 0) {
260 dmessage(1, "fromnet; waiting for connection... (inmsg %d)\n",
261 inmsg);
262 sleep(1000);
263 }
264
265 /* Read the header. */
266 if ((len = readn(netfd, buf, sizeof(buf))) <= 0) {
267 if (len < 0)
268 dmessage(1, "fromnet; (hdr) network failure; %r\n");
269 else
270 dmessage(1, "fromnet; (hdr) network closed\n");
271 close(netfd);
272 netfd = -1;
273 continue;
274 }
275 unpackhdr(&b->hdr, buf);
276 dmessage(2, "fromnet: Got message, size %d, nb %d, msg %d\n",
277 len, b->hdr.nb, b->hdr.msg);
278
279 if (b->hdr.nb == 0) {
280 if ((long)b->hdr.msg >= 0) {
281 dmessage(1, "fromnet; network closed\n");
282 break;
283 }
284 continue;
285 }
286
287 if ((len = readn(netfd, b->buf, b->hdr.nb)) <= 0 ||
288 len != b->hdr.nb) {
289 if (len == 0)
290 dmessage(1, "fromnet; network closed\n");
291 else
292 dmessage(1, "fromnet; network failure; %r\n");
293 close(netfd);
294 netfd = -1;
295 continue;
296 }
297
298 if (b->hdr.msg < inmsg) {
299 dmessage(1, "fromnet; skipping message %d, currently at %d\n",
300 b->hdr.msg, inmsg);
301 continue;
302 }
303
304 /* Process the acked list. */
305 acked = b->hdr.acked - lastacked;
306 for (i = 0; i != acked; i++) {
307 rb = recvp(unacked);
308 if (rb->hdr.msg != lastacked + i) {
309 dmessage(1, "rb %p, msg %d, lastacked %d, i %d\n",
310 rb, rb? rb->hdr.msg: -2, lastacked, i);
311 assert(0);
312 }
313 rb->hdr.msg = -1;
314 sendp(empty, rb);
315 }
316 lastacked = b->hdr.acked;
317 inmsg++;
318 showmsg(1, "fromnet", b);
319 if (writen(1, b->buf, len) < 0)
320 sysfatal("fromnet; cannot write to client; %r");
321 }
322 done = 1;
323 }
324
325 static void
reconnect(void)326 reconnect(void)
327 {
328 char err[32], ldir[40];
329 int lcfd, fd;
330 Endpoints *ep;
331
332 if (dialstring) {
333 syslog(0, Logname, "dialing %s", dialstring);
334 while ((fd = dial(dialstring, nil, nil, nil)) < 0) {
335 err[0] = '\0';
336 errstr(err, sizeof err);
337 if (strstr(err, "connection refused")) {
338 dmessage(1, "reconnect; server died...\n");
339 threadexitsall("server died...");
340 }
341 dmessage(1, "reconnect: dialed %s; %s\n", dialstring, err);
342 sleep(1000);
343 }
344 syslog(0, Logname, "reconnected to %s", dialstring);
345 } else {
346 syslog(0, Logname, "waiting for connection on %s", devdir);
347 alarm(maxto * 1000);
348 if ((lcfd = listen(devdir, ldir)) < 0)
349 sysfatal("reconnect; cannot listen; %r");
350
351 if ((fd = accept(lcfd, ldir)) < 0)
352 sysfatal("reconnect; cannot accept; %r");
353 alarm(0);
354 close(lcfd);
355
356 ep = getendpoints(ldir);
357 dmessage(1, "rsys '%s'\n", ep->rsys);
358 syslog(0, Logname, "connected from %s", ep->rsys);
359 freeendpoints(ep);
360 }
361 netfd = fd; /* Wakes up the netreader. */
362 }
363
364 static void
synchronize(void)365 synchronize(void)
366 {
367 Channel *tmp;
368 Buf *b;
369 uchar buf[Hdrsz];
370
371 /*
372 * Ignore network errors here. If we fail during
373 * synchronization, the next alarm will pick up
374 * the error.
375 */
376 tmp = chancreate(sizeof(Buf *), Nbuf);
377 while ((b = nbrecvp(unacked)) != nil) {
378 packhdr(&b->hdr, buf);
379 writen(netfd, buf, sizeof(buf));
380 writen(netfd, b->buf, b->hdr.nb);
381 sendp(tmp, b);
382 }
383 chanfree(unacked);
384 unacked = tmp;
385 }
386
387 static void
showmsg(int level,char * s,Buf * b)388 showmsg(int level, char *s, Buf *b)
389 {
390 if (b == nil) {
391 dmessage(level, "%s; b == nil\n", s);
392 return;
393 }
394 dmessage(level, "%s; (len %d) %X %X %X %X %X %X %X %X %X (%p)\n", s,
395 b->hdr.nb,
396 b->buf[0], b->buf[1], b->buf[2],
397 b->buf[3], b->buf[4], b->buf[5],
398 b->buf[6], b->buf[7], b->buf[8], b);
399 }
400
401 static int
writen(int fd,uchar * buf,int nb)402 writen(int fd, uchar *buf, int nb)
403 {
404 int n, len = nb;
405
406 while (nb > 0) {
407 if (fd < 0)
408 return -1;
409 if ((n = write(fd, buf, nb)) < 0) {
410 dmessage(1, "writen; Write failed; %r\n");
411 return -1;
412 }
413 dmessage(2, "writen: wrote %d bytes\n", n);
414 buf += n;
415 nb -= n;
416 }
417 return len;
418 }
419
420 static void
timerproc(void * x)421 timerproc(void *x)
422 {
423 Channel *timer = x;
424
425 while (!done) {
426 sleep((Synctime / MS(1)) >> 1);
427 sendp(timer, "timer");
428 }
429 }
430
431 static void
dmessage(int level,char * fmt,...)432 dmessage(int level, char *fmt, ...)
433 {
434 va_list arg;
435
436 if (level > debug)
437 return;
438 va_start(arg, fmt);
439 vfprint(2, fmt, arg);
440 va_end(arg);
441 }
442
443 static void
getendpoint(char * dir,char * file,char ** sysp,char ** servp)444 getendpoint(char *dir, char *file, char **sysp, char **servp)
445 {
446 int fd, n;
447 char buf[128];
448 char *sys, *serv;
449
450 sys = serv = 0;
451 snprint(buf, sizeof buf, "%s/%s", dir, file);
452 fd = open(buf, OREAD);
453 if(fd >= 0){
454 n = read(fd, buf, sizeof(buf)-1);
455 if(n>0){
456 buf[n-1] = 0;
457 serv = strchr(buf, '!');
458 if(serv){
459 *serv++ = 0;
460 serv = strdup(serv);
461 }
462 sys = strdup(buf);
463 }
464 close(fd);
465 }
466 if(serv == 0)
467 serv = strdup("unknown");
468 if(sys == 0)
469 sys = strdup("unknown");
470 *servp = serv;
471 *sysp = sys;
472 }
473
474 static Endpoints *
getendpoints(char * dir)475 getendpoints(char *dir)
476 {
477 Endpoints *ep;
478
479 ep = malloc(sizeof(*ep));
480 getendpoint(dir, "local", &ep->lsys, &ep->lserv);
481 getendpoint(dir, "remote", &ep->rsys, &ep->rserv);
482 return ep;
483 }
484
485 static void
freeendpoints(Endpoints * ep)486 freeendpoints(Endpoints *ep)
487 {
488 free(ep->lsys);
489 free(ep->rsys);
490 free(ep->lserv);
491 free(ep->rserv);
492 free(ep);
493 }
494
/*
 * Little-endian 32-bit load/store through a byte pointer p (a uchar*).
 * Arguments are parenthesized so expressions such as buf+4 work.  Each
 * byte is widened to unsigned long before shifting, so p[3]<<24 cannot
 * overflow a promoted signed int.  U32PUT evaluates v four times and
 * expands to a single statement, so it is safe in unbraced if/else.
 */
#define U32GET(p) ((unsigned long)(p)[0] | (unsigned long)(p)[1]<<8 | \
	(unsigned long)(p)[2]<<16 | (unsigned long)(p)[3]<<24)
#define U32PUT(p,v) do { (p)[0] = (v); (p)[1] = (v)>>8; \
	(p)[2] = (v)>>16; (p)[3] = (v)>>24; } while (0)
499
500 static void
packhdr(Hdr * hdr,uchar * buf)501 packhdr(Hdr *hdr, uchar *buf)
502 {
503 uchar *p;
504
505 p = buf;
506 U32PUT(p, hdr->nb);
507 p += 4;
508 U32PUT(p, hdr->msg);
509 p += 4;
510 U32PUT(p, hdr->acked);
511 }
512
513 static void
unpackhdr(Hdr * hdr,uchar * buf)514 unpackhdr(Hdr *hdr, uchar *buf)
515 {
516 uchar *p;
517
518 p = buf;
519 hdr->nb = U32GET(p);
520 p += 4;
521 hdr->msg = U32GET(p);
522 p += 4;
523 hdr->acked = U32GET(p);
524 }
525