xref: /plan9-contrib/sys/src/cmd/ssh2/dial.c (revision 63afb9a5d3f910047231762bcce0ee49fed3d07c)
1 /*
2  * dial - connect to a service (parallel version)
3  */
4 #include <u.h>
5 #include <libc.h>
6 #include <ctype.h>
7 
typedef struct Conn	Conn;
typedef struct Dest	Dest;
typedef struct DS	DS;

enum
{
	Maxstring	= 128,		/* size of the private dial string copy */
	Maxpath		= 256,		/* size of file name & ctl message buffers */

	/* upper bound on the text cs returns; deliberately generous */
	Maxcsreply	= 80*64,

	/*
	 * time limit, in ms, for each attempt of a parallel dial.  meant as
	 * a plausible slight overestimate for non-interactive use, even
	 * though it is far too long for interactive use.
	 */
	Maxconnms	= 120*1000,	/* 2 minutes */
};
24 
25 struct DS {
26 	/* dist string */
27 	char	buf[Maxstring];
28 	char	*netdir;		/* e.g., /net.alt */
29 	char	*proto;			/* e.g., tcp */
30 	char	*rem;			/* e.g., host!service */
31 
32 	/* other args */
33 	char	*local;
34 	char	*dir;
35 	int	*cfdp;
36 };
37 
38 /*
39  * malloc these; they need to be writable by this proc & all children.
40  * the stack is private to each proc, and static allocation in the data
41  * segment would not permit concurrent dials within a multi-process program.
42  */
43 struct Conn {
44 	int	pid;
45 	int	dead;
46 
47 	int	dfd;
48 	int	cfd;
49 	char	dir[NETPATHLEN];
50 	char	err[ERRMAX];
51 };
52 struct Dest {
53 	Conn	*conn;			/* allocated array */
54 	Conn	*connend;
55 	int	nkid;
56 
57 	long	oalarm;
58 	int	naddrs;
59 
60 	QLock	winlck;
61 	int	winner;			/* index into conn[] */
62 
63 	char	*nextaddr;
64 	char	addrlist[Maxcsreply];
65 };
66 
67 static int	call(char*, char*, DS*, Dest*, Conn*);
68 static int	csdial(DS*);
69 static void	_dial_string_parse(char*, DS*);
70 
71 
72 /*
73  *  the dialstring is of the form '[/net/]proto!dest'
74  */
75 static int
76 dialimpl(char *dest, char *local, char *dir, int *cfdp)
77 {
78 	DS ds;
79 	int rv;
80 	char err[ERRMAX], alterr[ERRMAX];
81 
82 	ds.local = local;
83 	ds.dir = dir;
84 	ds.cfdp = cfdp;
85 
86 	_dial_string_parse(dest, &ds);
87 	if(ds.netdir)
88 		return csdial(&ds);
89 
90 	ds.netdir = "/net";
91 	rv = csdial(&ds);
92 	if(rv >= 0)
93 		return rv;
94 	err[0] = '\0';
95 	errstr(err, sizeof err);
96 	if(strstr(err, "refused") != 0){
97 		werrstr("%s", err);
98 		return rv;
99 	}
100 	ds.netdir = "/net.alt";
101 	rv = csdial(&ds);
102 	if(rv >= 0)
103 		return rv;
104 
105 	alterr[0] = 0;
106 	errstr(alterr, sizeof alterr);
107 	if(strstr(alterr, "translate") || strstr(alterr, "does not exist"))
108 		werrstr("%s", err);
109 	else
110 		werrstr("%s", alterr);
111 	return rv;
112 }
113 
114 /*
115  * the thread library can't cope with rfork(RFMEM|RFPROC),
116  * so it must override this with a private version of dial.
117  */
118 int (*_dial)(char *, char *, char *, int *) = dialimpl;
119 
120 int
121 dial(char *dest, char *local, char *dir, int *cfdp)
122 {
123 	return (*_dial)(dest, local, dir, cfdp);
124 }
125 
126 static int
127 connsalloc(Dest *dp, int addrs)
128 {
129 	Conn *conn;
130 
131 	free(dp->conn);
132 	dp->connend = nil;
133 	assert(addrs > 0);
134 
135 	dp->conn = mallocz(addrs * sizeof *dp->conn, 1);
136 	if(dp->conn == nil)
137 		return -1;
138 	dp->connend = dp->conn + addrs;
139 	for(conn = dp->conn; conn < dp->connend; conn++)
140 		conn->cfd = conn->dfd = -1;
141 	return 0;
142 }
143 
144 static void
145 freedest(Dest *dp)
146 {
147 	long oalarm;
148 
149 	if (dp == nil)
150 		return;
151 	oalarm = dp->oalarm;
152 	free(dp->conn);
153 	free(dp);
154 	if (oalarm >= 0)
155 		alarm(oalarm);
156 }
157 
/*
 * close *fdp if it is open, and mark it closed (-1).
 */
static void
closeopenfd(int *fdp)
{
	if(*fdp < 0)
		return;
	close(*fdp);
	*fdp = -1;
}
166 
167 static void
168 notedeath(Dest *dp, char *exitsts)
169 {
170 	int i, n, pid;
171 	char *fields[5];			/* pid + 3 times + error */
172 	Conn *conn;
173 
174 	for (i = 0; i < nelem(fields); i++)
175 		fields[i] = "";
176 	n = tokenize(exitsts, fields, nelem(fields));
177 	if (n < 4)
178 		return;
179 	pid = atoi(fields[0]);
180 	if (pid <= 0)
181 		return;
182 	for (conn = dp->conn; conn < dp->connend; conn++)
183 		if (conn->pid == pid && !conn->dead) {  /* it's one we know? */
184 			if (conn - dp->conn != dp->winner) {
185 				closeopenfd(&conn->dfd);
186 				closeopenfd(&conn->cfd);
187 			}
188 			strncpy(conn->err, fields[4], sizeof conn->err);
189 			conn->dead = 1;
190 			return;
191 		}
192 	/* not a proc that we forked */
193 }
194 
195 static int
196 outstandingprocs(Dest *dp)
197 {
198 	Conn *conn;
199 
200 	for (conn = dp->conn; conn < dp->connend; conn++)
201 		if (!conn->dead)
202 			return 1;
203 	return 0;
204 }
205 
206 static int
207 reap(Dest *dp)
208 {
209 	char exitsts[2*ERRMAX];
210 
211 	if (outstandingprocs(dp) && await(exitsts, sizeof exitsts) >= 0) {
212 		notedeath(dp, exitsts);
213 		return 0;
214 	}
215 	return -1;
216 }
217 
218 static int
219 fillinds(DS *ds, Dest *dp)
220 {
221 	Conn *conn;
222 
223 	if (dp->winner < 0)
224 		return -1;
225 	conn = &dp->conn[dp->winner];
226 	if (ds->cfdp)
227 		*ds->cfdp = conn->cfd;
228 	if (ds->dir)
229 		strncpy(ds->dir, conn->dir, NETPATHLEN);
230 	return conn->dfd;
231 }
232 
233 static int
234 connectwait(Dest *dp, char *besterr)
235 {
236 	Conn *conn;
237 
238 	/* wait for a winner or all attempts to time out */
239 	while (dp->winner < 0 && reap(dp) >= 0)
240 		;
241 
242 	/* kill all of our still-live kids & reap them */
243 	for (conn = dp->conn; conn < dp->connend; conn++)
244 		if (!conn->dead)
245 			postnote(PNPROC, conn->pid, "alarm");
246 	while (reap(dp) >= 0)
247 		;
248 
249 	/* rummage about and report some error string */
250 	for (conn = dp->conn; conn < dp->connend; conn++)
251 		if (conn - dp->conn != dp->winner && conn->dead &&
252 		    conn->err[0]) {
253 			strncpy(besterr, conn->err, ERRMAX);
254 			break;
255 		}
256 	return dp->winner;
257 }
258 
259 static int
260 parsecs(Dest *dp, char **clonep, char **destp)
261 {
262 	char *dest, *p;
263 
264 	dest = strchr(dp->nextaddr, ' ');
265 	if(dest == nil)
266 		return -1;
267 	*dest++ = '\0';
268 	p = strchr(dest, '\n');
269 	if(p == nil)
270 		return -1;
271 	*p++ = '\0';
272 	*clonep = dp->nextaddr;
273 	*destp = dest;
274 	dp->nextaddr = p;		/* advance to next line */
275 	return 0;
276 }
277 
278 static void
279 pickuperr(char *besterr, char *err)
280 {
281 	err[0] = '\0';
282 	errstr(err, ERRMAX);
283 	if(strstr(err, "does not exist") == 0)
284 		strcpy(besterr, err);
285 }
286 
287 static void
288 catcher(void *, char *s)
289 {
290 	if (strstr(s, "alarm") != nil)
291 		noted(NCONT);
292 	else
293 		noted(NDFLT);
294 }
295 
296 /*
297  * try all addresses in parallel and take the first one that answers;
298  * this helps when systems have ip v4 and v6 addresses but are
299  * only reachable from here on one (or some) of them.
300  */
301 static int
302 dialmulti(DS *ds, Dest *dp)
303 {
304 	int rv, kid, kidme;
305 	char *clone, *dest;
306 	char err[ERRMAX], besterr[ERRMAX];
307 
308 	dp->winner = -1;
309 	dp->nkid = 0;
310 	while(dp->winner < 0 && *dp->nextaddr != '\0' &&
311 	    parsecs(dp, &clone, &dest) >= 0) {
312 		kidme = dp->nkid++;		/* make private copy on stack */
313 		kid = rfork(RFPROC|RFMEM);	/* spin off a call attempt */
314 		if (kid < 0)
315 			--dp->nkid;
316 		else if (kid == 0) {
317 			/* only in kid, to avoid atnotify callbacks in parent */
318 			notify(catcher);
319 
320 			*besterr = '\0';
321 			rv = call(clone, dest, ds, dp, &dp->conn[kidme]);
322 			if(rv < 0)
323 				pickuperr(besterr, err);
324 			_exits(besterr);	/* avoid atexit callbacks */
325 		}
326 	}
327 	rv = connectwait(dp, besterr);
328 	if(rv < 0 && *besterr)
329 		werrstr("%s", besterr);
330 	else
331 		werrstr("%s", err);
332 	return rv;
333 }
334 
335 static int
336 csdial(DS *ds)
337 {
338 	int n, fd, rv, addrs, bleft;
339 	char c;
340 	char *addrp, *clone2, *dest;
341 	char buf[Maxstring], clone[Maxpath], err[ERRMAX], besterr[ERRMAX];
342 	Dest *dp;
343 
344 	dp = mallocz(sizeof *dp, 1);
345 	if(dp == nil)
346 		return -1;
347 	dp->winner = -1;
348 	dp->oalarm = alarm(0);
349 	if (connsalloc(dp, 1) < 0) {		/* room for a single conn. */
350 		freedest(dp);
351 		return -1;
352 	}
353 
354 	/*
355 	 *  open connection server
356 	 */
357 	snprint(buf, sizeof(buf), "%s/cs", ds->netdir);
358 	fd = open(buf, ORDWR);
359 	if(fd < 0){
360 		/* no connection server, don't translate */
361 		snprint(clone, sizeof(clone), "%s/%s/clone", ds->netdir, ds->proto);
362 		rv = call(clone, ds->rem, ds, dp, &dp->conn[0]);
363 		fillinds(ds, dp);
364 		freedest(dp);
365 		return rv;
366 	}
367 
368 	/*
369 	 *  ask connection server to translate
370 	 */
371 	snprint(buf, sizeof(buf), "%s!%s", ds->proto, ds->rem);
372 	if(write(fd, buf, strlen(buf)) < 0){
373 		close(fd);
374 		freedest(dp);
375 		return -1;
376 	}
377 
378 	/*
379 	 *  read all addresses from the connection server.
380 	 */
381 	seek(fd, 0, 0);
382 	addrs = 0;
383 	addrp = dp->nextaddr = dp->addrlist;
384 	bleft = sizeof dp->addrlist - 2;	/* 2 is room for \n\0 */
385 	while(bleft > 0 && (n = read(fd, addrp, bleft)) > 0) {
386 		if (addrp[n-1] != '\n')
387 			addrp[n++] = '\n';
388 		addrs++;
389 		addrp += n;
390 		bleft -= n;
391 	}
392 	/*
393 	 * if we haven't read all of cs's output, assume the last line might
394 	 * have been truncated and ignore it.  we really don't expect this
395 	 * to happen.
396 	 */
397 	if (addrs > 0 && bleft <= 0 && read(fd, &c, 1) == 1)
398 		addrs--;
399 	close(fd);
400 
401 	*besterr = 0;
402 	rv = -1;				/* pessimistic default */
403 	dp->naddrs = addrs;
404 	if (addrs == 0)
405 		werrstr("no address to dial");
406 	else if (addrs == 1) {
407 		/* common case: dial one address without forking */
408 		if (parsecs(dp, &clone2, &dest) >= 0 &&
409 		    (rv = call(clone2, dest, ds, dp, &dp->conn[0])) < 0) {
410 			pickuperr(besterr, err);
411 			werrstr("%s", besterr);
412 		}
413 	} else if (connsalloc(dp, addrs) >= 0)
414 		rv = dialmulti(ds, dp);
415 
416 	/* fill in results */
417 	if (rv >= 0 && dp->winner >= 0)
418 		rv = fillinds(ds, dp);
419 
420 	freedest(dp);
421 	return rv;
422 }
423 
424 static int
425 call(char *clone, char *dest, DS *ds, Dest *dp, Conn *conn)
426 {
427 	int fd, cfd, n, calleralarm, oalarm;
428 	char cname[Maxpath], name[Maxpath], data[Maxpath], *p;
429 
430 	/* because cs is in a different name space, replace the mount point */
431 	if(*clone == '/'){
432 		p = strchr(clone+1, '/');
433 		if(p == nil)
434 			p = clone;
435 		else
436 			p++;
437 	} else
438 		p = clone;
439 	snprint(cname, sizeof cname, "%s/%s", ds->netdir, p);
440 
441 	conn->pid = getpid();
442 	conn->cfd = cfd = open(cname, ORDWR);
443 	if(cfd < 0)
444 		return -1;
445 
446 	/* get directory name */
447 	n = read(cfd, name, sizeof(name)-1);
448 	if(n < 0){
449 		closeopenfd(&conn->cfd);
450 		return -1;
451 	}
452 	name[n] = 0;
453 	for(p = name; *p == ' '; p++)
454 		;
455 	snprint(name, sizeof(name), "%ld", strtoul(p, 0, 0));
456 	p = strrchr(cname, '/');
457 	*p = 0;
458 	if(ds->dir)
459 		snprint(conn->dir, NETPATHLEN, "%s/%s", cname, name);
460 	snprint(data, sizeof(data), "%s/%s/data", cname, name);
461 
462 	/* should be no alarm pending now; re-instate caller's alarm, if any */
463 	calleralarm = dp->oalarm > 0;
464 	if (calleralarm)
465 		alarm(dp->oalarm);
466 	else if (dp->naddrs > 1)	/* in a sub-process? */
467 		alarm(Maxconnms);
468 
469 	/* connect */
470 	if(ds->local)
471 		snprint(name, sizeof(name), "connect %s %s", dest, ds->local);
472 	else
473 		snprint(name, sizeof(name), "connect %s", dest);
474 	if(write(cfd, name, strlen(name)) < 0){
475 		closeopenfd(&conn->cfd);
476 		return -1;
477 	}
478 
479 	oalarm = alarm(0);	/* don't let alarm interrupt critical section */
480 	if (calleralarm)
481 		dp->oalarm = oalarm;	/* time has passed, so update user's */
482 
483 	/* open data connection */
484 	conn->dfd = fd = open(data, ORDWR);
485 	if(fd < 0){
486 		closeopenfd(&conn->cfd);
487 		alarm(dp->oalarm);
488 		return -1;
489 	}
490 	if(ds->cfdp == nil)
491 		closeopenfd(&conn->cfd);
492 
493 	n = conn - dp->conn;
494 	if (dp->winner < 0) {
495 		qlock(&dp->winlck);
496 		if (dp->winner < 0 && conn < dp->connend)
497 			dp->winner = n;
498 		qunlock(&dp->winlck);
499 	}
500 	alarm(calleralarm? dp->oalarm: 0);
501 	return fd;
502 }
503 
/*
 * p points at the first '!' in a dial string that starts at st.
 * numeric components just before it (e.g. the 0 of ssh/0) are channel
 * subdirectories that count as part of the proto, so step back over
 * each "/digits" component, treating a run of slashes as one.
 * returns a pointer to the delimiter after the right-most non-numeric
 * component.  works by index from st, never forming a pointer before it.
 */
static char *
backoverchans(char *st, char *p)
{
	char *sl, *q;

	sl = q = p;
	while(q > st && q[-1] >= '0' && q[-1] <= '9'){
		/* back over a run of ASCII digits */
		do
			q--;
		while(q > st && q[-1] >= '0' && q[-1] <= '9');

		/* only a whole component (preceded by '/') is a channel */
		if(q == st || q[-1] != '/')
			break;			/* "net.alt2" or ran off start */

		/* move onto the first slash of the run of slashes */
		q--;
		while(q > st && q[-1] == '/')
			q--;
		sl = q;
	}
	return sl;
}
525 
526 /*
527  *  parse a dial string
528  */
529 static void
530 _dial_string_parse(char *str, DS *ds)
531 {
532 	char *p, *p2;
533 
534 	strncpy(ds->buf, str, Maxstring);
535 	ds->buf[Maxstring-1] = 0;
536 
537 	p = strchr(ds->buf, '!');
538 	if(p == 0) {
539 		ds->netdir = 0;
540 		ds->proto = "net";
541 		ds->rem = ds->buf;
542 	} else {
543 		if(*ds->buf != '/' && *ds->buf != '#'){
544 			ds->netdir = 0;
545 			ds->proto = ds->buf;
546 		} else {
547 			p2 = backoverchans(ds->buf, p);
548 
549 			/* back over last component of netdir (proto) */
550 			while (--p2 > ds->buf && *p2 != '/')
551 				;
552 			*p2++ = 0;
553 			ds->netdir = ds->buf;
554 			ds->proto = p2;
555 		}
556 		*p = 0;
557 		ds->rem = p + 1;
558 	}
559 }
560