xref: /plan9/sys/src/cmd/ssh2/ssh2.c (revision 515d8088f74049961cf7be2f37185376da6b74b2)
1 /*
2  * ssh - remote login via SSH v2
3  *	/net/ssh does most of the work; we copy bytes back and forth
4  */
5 #include <u.h>
6 #include <libc.h>
7 #include <auth.h>
8 #include "ssh2.h"
9 
/* defined at the end of this file */
int	doauth(int cfd, char *whichkey);
int	isatty(int fd);

/* connection parameters, filled in from the command line */
char	*user;
char	*remote;
char	*netdir = "/net";
int	debug = 0;

/* option flags */
static int stripcr = 0;		/* -r: drop \r from data copied to stdout */
static int mflag = 0;		/* -m: disable the ^\ command-mode escape */
static int iflag = -1;		/* -i/-I; -1 means decide with isatty(0) */
static int nopw = 0;		/* -K: no password fallback in doauth */
static int nopka = 0;		/* -k: skip public-key auth in doauth */

/* pid of the key-handling subprocess; keyproc clears it when done */
static int chpid;

/* file descriptors; main sets every one of them to -1 first */
static int reqfd;		/* <conn>/<chan>/request */
static int dfd1, cfd1;		/* data and ctl of the ssh connection */
static int dfd2, cfd2;		/* data and ctl of the session channel */
static int consfd, cctlfd;	/* /dev/cons and /dev/consctl */
static int kconsfd, keyfd;	/* keyproc's /dev/cons and ssh/keys */
static int notefd;		/* notepg of our note group */
23 
/*
 * complain about the command line and exit.
 */
void
usage(void)
{
	fprint(2, "usage: %s [-dkKmr] [-l user] [-n dir] [-z attr=val] "
		"addr [cmd [args]]\n", argv0);
	exits("usage");
}
31 
/*
 * close a descriptor if it is open and mark it closed (-1), so that
 * a later shutdown cannot touch a descriptor number that has been
 * reused in the meantime.
 */
static void
dropfd(int *fdp)
{
	if (*fdp >= 0) {
		close(*fdp);
		*fdp = -1;
	}
}

/*
 * undo console raw mode, close the request channel and the
 * connection, and post "kill" to our note group.
 * apart from writing "kill" to notefd this is probably overkill,
 * since the kernel closes file descriptors upon exit.
 * shutdown may run more than once (from handler, then from bail,
 * or again when the "kill" note it posts arrives at handler), so
 * every descriptor is reset to -1 once closed, and notefd is
 * cleared before "kill" is written so the note is posted only once.
 */
static void
shutdown(void)
{
	int fd;

	if (cctlfd >= 0) {
		fprint(cctlfd, "rawoff");
		dropfd(&cctlfd);
	}
	dropfd(&consfd);
	if (reqfd >= 0) {
		fprint(reqfd, "close");
		dropfd(&reqfd);
	}
	dropfd(&dfd2);
	dropfd(&dfd1);
	dropfd(&cfd2);
	dropfd(&cfd1);

	if (notefd >= 0) {
		fd = notefd;
		notefd = -1;
		fprint(fd, "kill");
		close(fd);
	}
}
57 
/*
 * tear everything down via shutdown and exit with status why
 * (nil for a normal exit).
 */
static void
bail(char *why)
{
	shutdown();
	exits(why);
}
64 
/*
 * note handler installed by main.
 * alarm notes are not handled here (returns 0; timedmount's catcher
 * deals with them).  any other note is forwarded to the key
 * subprocess as "interrupt", the connection is torn down, and 1 is
 * returned to mark the note as handled.
 */
int
handler(void *, char *s)
{
	char *nf;
	int fd;

	if (strstr(s, "alarm") != nil)
		return 0;
	if (chpid) {
		nf = esmprint("/proc/%d/note", chpid);
		fd = open(nf, OWRITE);
		/* best effort: the subprocess may already be gone */
		if (fd >= 0) {
			fprint(fd, "interrupt");
			close(fd);
		}
		free(nf);
	}
	shutdown();
	return 1;
}
83 
/*
 * split the command-line address in remote:
 *	[user@][netdir!]host...
 * a user@ prefix sets user.  a netdir! prefix that is a path sets
 * netdir to that directory without any trailing /ssh, and leaves
 * remote in the form <netdir>/ssh!host..., which is how an explicit
 * "/net/ssh!host" is already written.
 */
static void
parseargs(void)
{
	int n;
	char *p, *q;

	/* user@rest */
	q = strchr(remote, '@');
	if (q != nil) {
		user = remote;
		*q++ = '\0';
		remote = q;
	}

	/* netdir!rest */
	q = strchr(remote, '!');
	if (q == nil)
		return;

	n = q - remote;
	netdir = malloc(n+1);
	if (netdir == nil)
		sysfatal("out of memory");
	memmove(netdir, remote, n);
	netdir[n] = '\0';

	p = strrchr(netdir, '/');
	if (p == nil) {
		/*
		 * not a path (e.g. "tcp!host"): keep the default netdir and
		 * leave remote alone.
		 * NOTE(review): connect then dials remote as given rather
		 * than through <netdir>/ssh; confirm this is intended.
		 */
		free(netdir);
		netdir = "/net";
	} else if (strcmp(p+1, "ssh") == 0)
		*p = '\0';	/* "/net.alt/ssh!host": remote is already right */
	else
		/*
		 * "/net.alt!host" -> "/net.alt/ssh!host".  q still points
		 * at the '!' in remote, so the host (and anything after it)
		 * is kept; previously remote became just "<netdir>/ssh"
		 * and the host was lost.
		 */
		remote = esmprint("%s/ssh%s", netdir, q);
}
117 
/*
 * atnotify handler used around mount: claim alarm notes so the
 * timed-out mount call is simply interrupted; ignore all others.
 */
static int
catcher(void *, char *s)
{
	if (strstr(s, "alarm") == nil)
		return 0;
	return 1;
}
123 
/*
 * mount(2) guarded by a five-second alarm so an unresponsive server
 * cannot hang us; catcher absorbs the alarm note.  whatever alarm
 * was pending before is re-armed afterwards.  returns mount's result.
 */
static int
timedmount(int fd, int afd, char *mntpt, int flag, char *aname)
{
	int prev, r;

	atnotify(catcher, 1);
	prev = alarm(5000);		/* don't get stuck here */
	r = mount(fd, afd, mntpt, flag, aname);
	alarm(prev);
	atnotify(catcher, 0);
	return r;
}
136 
/*
 * try to mount the netssh service posted at srv before netdir.
 * failures are reported (the open failure only with -d) but are
 * not fatal; the caller checks whether <netdir>/ssh/keys appeared.
 */
static void
mounttunnel(char *srv)
{
	int fd;

	if (debug)
		fprint(2, "%s: mounting %s on %s\n", argv0, srv, netdir);
	fd = open(srv, OREAD);
	if (fd < 0) {
		if (debug)
			fprint(2, "%s: can't open %s: %r\n", argv0, srv);
		return;
	}
	if (timedmount(fd, -1, netdir, MBEFORE, "") < 0)
		fprint(2, "can't mount %s on %s: %r\n", srv, netdir);
	/* the mount holds its own reference; don't leak fd either way */
	close(fd);
}
153 
/*
 * run /bin/netssh -m netdir -s myname in a child with its own note
 * group and environment group, and wait for that child to exit.
 * starttunnel then looks for the service at /srv/<myname>.
 * RFFDG is not set, so the child shares our file descriptor table.
 * exits if the fork fails; the child exits if the exec fails.
 */
static void
newtunnel(char *myname)
{
	int child, w;

	if(debug)
		fprint(2, "%s: starting new netssh for key access\n", argv0);
	switch (child = rfork(RFPROC|RFNOTEG|RFENVG)) {
	case -1:
		sysfatal("fork failed: %r");
		break;		/* not reached */
	case 0:
		execl("/bin/netssh", "netssh", "-m", netdir, "-s", myname, nil);
		sysfatal("no /bin/netssh: %r");
		break;		/* not reached */
	}
	do
		w = waitpid();
	while (w != child && w >= 0);
}
172 
/*
 * make <netdir>/ssh/keys accessible, stopping at the first step that
 * works: mount /srv/netssh (old name), /srv/ssh, or our own
 * /srv/ssh.<user>; failing that, start a new netssh and mount what
 * it posted.  exits if the keys file is still inaccessible.
 */
static void
starttunnel(void)
{
	int i;
	char *keyfile, *svcname, *svcpath;
	char *candidates[3];

	keyfile = esmprint("%s/ssh/keys", netdir);
	svcname = esmprint("ssh.%s", getuser());
	svcpath = esmprint("/srv/%s", svcname);

	candidates[0] = "/srv/netssh";		/* old name */
	candidates[1] = "/srv/ssh";
	candidates[2] = svcpath;
	for (i = 0; i < nelem(candidates) && access(keyfile, ORDWR) < 0; i++)
		mounttunnel(candidates[i]);

	if (access(keyfile, ORDWR) < 0) {
		newtunnel(svcname);
		if (access(keyfile, ORDWR) < 0)
			mounttunnel(svcpath);
	}

	/* if we *still* can't see our own tunnel, throw a tantrum. */
	if (access(keyfile, ORDWR) < 0)
		sysfatal("%s inaccessible: %r", keyfile);

	free(svcname);
	free(svcpath);
	free(keyfile);
}
201 
/*
 * local command mode, entered from bidircopy when ^\ starts a line.
 * reads a command line from standard input, echoing it ourselves
 * (main put the console in raw mode); ^U discards the line.
 * returns 1 to end the session (q, or EOF/error on stdin) and
 * 0 to resume copying (c, or r after toggling \r stripping).
 */
int
cmdmode(void)
{
	int n, m;
	char buf[Arbbufsz];

	for(;;) {
reprompt:
		print("\n>>> ");
		n = 0;
		do {
			/*
			 * with a full buffer the read below would ask for 0
			 * bytes, and its 0 return would look like EOF and end
			 * the whole session; start the line over instead.
			 */
			if (n >= sizeof buf - 1) {
				print("\nline too long\n");
				goto reprompt;
			}
			m = read(0, buf + n, sizeof buf - n - 1);
			if (m <= 0)
				return 1;
			write(1, buf + n, m);
			n += m;
			buf[n] = '\0';
			if (buf[n-1] == ('u' & 037))	/* ^U: kill line */
				goto reprompt;
		} while (buf[n-1] != '\n' && buf[n-1] != '\r');
		switch (buf[0]) {
		case '\n':
		case '\r':
			break;
		case 'q':
			return 1;
		case 'c':
			return 0;
		case 'r':
			stripcr = !stripcr;
			return 0;
		case 'h':
			print("c - continue\n");
			print("h - help\n");
			print("q - quit\n");
			print("r - toggle carriage return stripping\n");
			break;
		default:
			print("unknown command\n");
			break;
		}
	}
}
245 
/*
 * ask the user about the host key offered by the server.
 * on entry buf holds the keys-file message keyproc read: a 5-byte
 * header (type letter, then decimal length) followed by n bytes of
 * key text.  type 'c' is presented as a new key; anything else as a
 * key that does not match the known ones.  the user's answer is
 * written to keyfd and one reply (same header format) is read back.
 * for reply types 'b' and 'f' its text is shown on the console; for
 * 'b', 'f' and 'o' both keyfd and kconsfd are then closed.
 * size is the capacity of buf (keyproc guarantees at least 6).
 */
static void
keyprompt(char *buf, int size, int n)
{
	/* keyproc passes a negative n when its read failed */
	if (n < 0)
		n = 0;
	if (n > size - 6)
		n = size - 6;
	if (*buf == 'c') {
		fprint(kconsfd, "The following key has been offered by the server:\n");
		write(kconsfd, buf+5, n);
		fprint(kconsfd, "\n\n");
		fprint(kconsfd, "Add this key? (yes, no, session) ");
	} else {
		fprint(kconsfd, "The following key does NOT match the known "
			"key(s) for the server:\n");
		write(kconsfd, buf+5, n);
		fprint(kconsfd, "\n\n");
		fprint(kconsfd, "Add this key? (yes, no, session, replace) ");
	}
	n = read(kconsfd, buf, size - 1);
	if (n <= 0)
		return;
	write(keyfd, buf, n);		/* user's response -> /net/ssh/keys */
	seek(keyfd, 0, 2);
	if (readn(keyfd, buf, 5) <= 0)
		return;
	buf[5] = 0;
	n = strtol(buf+1, nil, 10);
	/* the length comes from the keys file: never read past buf */
	if (n > size - 6)
		n = size - 6;
	n = readn(keyfd, buf+5, n);
	if (n <= 0)
		return;
	buf[n+5] = 0;

	switch (*buf) {
	case 'b':
	case 'f':
		fprint(kconsfd, "%s\n", buf+5);
		/* fall through */
	case 'o':
		close(keyfd);
		close(kconsfd);
		break;
	default:
		break;
	}
}
284 
/*
 * talk the undocumented /net/ssh/keys protocol.
 * runs in the subprocess main forks with RFMEM (so globals such as
 * chpid and nopw are shared with main) and never returns.
 * reads one message from <netdir>/ssh/keys: a 5-byte header (type
 * letter, then decimal length) followed by that many bytes of text.
 * 'f' messages are shown on the console, 'f' and 'o' end the
 * exchange, and anything else is handed to keyprompt (or answered
 * "n" when there is no console).  chpid is cleared before exiting
 * so main does not signal a process that is gone.
 * buf is scratch space of size bytes, at least 6.
 */
static void
keyproc(char *buf, int size)
{
	int n;
	char *p;

	if (size < 6)
		exits("keyproc buffer too small");
	p = esmprint("%s/ssh/keys", netdir);
	keyfd = open(p, ORDWR);
	if (keyfd < 0) {
		chpid = 0;
		sysfatal("failed to open ssh keys in %s: %r", p);
	}

	kconsfd = open("/dev/cons", ORDWR);
	if (kconsfd < 0)
		nopw = 1;	/* no console: doauth must not try passwords */

	/* a short read must not leave stale bytes in the header */
	memset(buf, 0, 6);
	n = read(keyfd, buf, 5);		/* reading /net/ssh/keys */
	if (n < 0) {
		chpid = 0;
		sysfatal("%s read: %r", p);
	}
	buf[5] = 0;
	n = strtol(buf+1, nil, 10);
	/* the length comes from the keys file: never read past buf */
	if (n > size - 6)
		n = size - 6;
	n = readn(keyfd, buf+5, n);
	buf[n < 0? 5: n+5] = 0;
	free(p);

	switch (*buf) {
	case 'f':
		if (kconsfd >= 0)
			fprint(kconsfd, "%s\n", buf+5);
		/* fall through */
	case 'o':
		close(keyfd);
		if (kconsfd >= 0)
			close(kconsfd);
		break;
	default:
		if (kconsfd >= 0)
			keyprompt(buf, size, n);
		else {
			fprint(keyfd, "n");
			close(keyfd);
		}
		break;
	}
	chpid = 0;
	exits(nil);
}
337 
/*
 * copy bytes in both directions until one side ends, then bail.
 * first move into a note group of our own and open its notepg so
 * shutdown can post "kill" to both copiers.  a subprocess (RFMEM, so
 * it sees stripcr toggled by cmdmode) copies the session channel to
 * standard output, dropping \r when stripcr is set; this process
 * copies standard input to the channel.  unless -m was given, ^\
 * (0x1c) at the start of a line enters cmdmode.  a failed write
 * ends the copy, since its destination is gone.
 * buf is scratch space of size bytes.
 */
static void
bidircopy(char *buf, int size)
{
	int i, n, lstart;
	char *path, *p, *q;

	rfork(RFNOTEG);
	path = esmprint("/proc/%d/notepg", getpid());
	notefd = open(path, OWRITE);
	free(path);

	switch (rfork(RFPROC|RFMEM|RFNOWAIT)) {
	case 0:
		/* network -> stdout */
		while ((n = read(dfd2, buf, size - 1)) > 0) {
			if (!stripcr)
				p = buf + n;
			else
				for (i = 0, p = buf, q = buf; i < n; ++i, ++q)
					if (*q != '\r')
						*p++ = *q;
			if (p != buf && write(1, buf, p-buf) != p-buf)
				break;
		}
		/*
		 * no "connection closed by server" message; it will be
		 * obvious when the user's prompt changes.
		 */
		break;
	default:
		/* stdin -> network */
		lstart = 1;
		while ((n = read(0, buf, size - 1)) > 0) {
			if (!mflag && lstart && buf[0] == 0x1c) {
				if (cmdmode())
					break;
				continue;
			}
			lstart = (buf[n-1] == '\n' || buf[n-1] == '\r');
			if (write(dfd2, buf, n) != n)
				break;
		}
		/*
		 * no "EOF on client side" message; it will be obvious
		 * when the user's prompt changes.
		 */
		break;
	case -1:
		fprint(2, "%s: fork error: %r\n", argv0);
		break;
	}

	bail(nil);
}
396 
/*
 * dial remote's port 22 through <netdir>/ssh, leaving the data and
 * ctl descriptors in dfd1 and cfd1, and return the connection number
 * read from cfd1.  if the dial fails the key subprocess (if any) is
 * sent "interrupt" and we exit; if the number cannot be read we bail
 * instead of silently using connection 0.  buf is scratch space.
 */
static int
connect(char *buf, int size)
{
	int nfd, n;
	char *dir, *ds, *nf;

	dir = esmprint("%s/ssh", netdir);
	ds = netmkaddr(remote, dir, "22");		/* tcp port 22 is ssh */
	free(dir);

	dfd1 = dial(ds, nil, nil, &cfd1);
	if (dfd1 < 0) {
		fprint(2, "%s: dial conn %s: %r\n", argv0, ds);
		if (chpid) {
			nf = esmprint("/proc/%d/note", chpid);
			nfd = open(nf, OWRITE);
			if (nfd >= 0) {
				fprint(nfd, "interrupt");
				close(nfd);
			}
			free(nf);
		}
		exits("can't dial");
	}

	/* rewind: presumably dial leaves cfd1's offset past its own read */
	seek(cfd1, 0, 0);
	n = read(cfd1, buf, size - 1);
	if (n <= 0) {
		fprint(2, "%s: can't read connection number for %s: %r\n",
			argv0, ds);
		bail("can't read conn");
	}
	buf[n] = 0;
	return atoi(buf);
}
424 
/*
 * open a session channel on connection conn, leaving its data and
 * ctl descriptors in dfd2 and cfd2, and return the channel number
 * read from cfd2.  bails if the dial fails or the number cannot be
 * read.  buf is scratch space of size bytes.
 */
static int
chanconnect(int conn, char *buf, int size)
{
	int n;
	char *path;

	path = esmprint("%s/ssh/%d!session", netdir, conn);
	dfd2 = dial(path, nil, nil, &cfd2);
	if (dfd2 < 0) {
		fprint(2, "%s: dial chan %s: %r\n", argv0, path);
		bail("dial");
	}

	/*
	 * rewind before reading the number, as connect does for cfd1;
	 * without it a server that honours offsets returns nothing and
	 * atoi("") silently yields channel 0.
	 */
	seek(cfd2, 0, 0);
	n = read(cfd2, buf, size - 1);
	if (n <= 0) {
		fprint(2, "%s: can't read channel number for %s: %r\n",
			argv0, path);
		bail("dial");
	}
	free(path);
	buf[n] = 0;
	return atoi(buf);
}
443 
/*
 * open the request file of channel chan on connection conn (kept in
 * reqfd) and ask for either a shell (no args; $TERM is passed along
 * when readable) or "exec" of the quoted argv.  bails if the request
 * file cannot be opened or the command does not fit in Bigbufsz.
 * buf is scratch space of size bytes.
 */
static void
remotecmd(int argc, char *argv[], int conn, int chan, char *buf, int size)
{
	int i;
	char *path, *q, *ep;

	path = esmprint("%s/ssh/%d/%d/request", netdir, conn, chan);
	reqfd = open(path, OWRITE);
	free(path);
	if (reqfd < 0)
		bail("can't open request chan");

	if (argc == 0) {
		if (readfile("/env/TERM", buf, size) < 0)
			fprint(reqfd, "shell");
		else
			fprint(reqfd, "shell %s", buf);
		return;
	}

	assert(size >= Bigbufsz);
	ep = buf + Bigbufsz;
	q = seprint(buf, ep, "exec");
	for (i = 0; i < argc; ++i)
		q = seprint(q, ep, " %q", argv[i]);
	/*
	 * seprint never goes past ep-1, where it stores the NUL, so a
	 * result that reaches ep-1 means the command may be truncated
	 * (the old q >= ep test could never fire).
	 */
	if (q >= ep - 1) {
		fprint(2, "%s: command too long\n", argv0);
		bail("cmd too long");	/* shutdown writes "close" to reqfd */
	}
	write(reqfd, buf, q - buf);
}
473 
/*
 * ssh [options] [user@][netdir!]host [cmd [args]]
 * make sure a netssh tunnel is mounted, start the host-key helper,
 * connect and authenticate, open a session channel, request a shell
 * or cmd, then copy bytes between it and our stdin/stdout.
 */
void
main(int argc, char *argv[])
{
	char *whichkey;
	int conn, chan, n;
	char buf[Copybufsz];

	quotefmtinstall();
	reqfd = dfd1 = cfd1 = dfd2 = cfd2 = consfd = kconsfd = cctlfd =
		notefd = keyfd = -1;
	whichkey = nil;
	ARGBEGIN {
	case 'A':			/* auth protos */
	case 'c':			/* ciphers */
		fprint(2, "%s: sorry, -%c is not supported\n", argv0, ARGC());
		break;
	case 'a':			/* compat? */
	case 'C':			/* cooked mode */
	case 'f':			/* agent forwarding */
	case 'p':			/* force pty */
	case 'P':			/* force no pty */
	case 'R':			/* force raw mode on pty */
	case 'v':			/* scp compat */
	case 'w':			/* send window-size changes */
	case 'x':			/* unix compat: no x11 forwarding */
		break;
	case 'd':
		debug++;
		break;
	case 'I':			/* non-interactive */
		iflag = 0;
		break;
	case 'i':			/* interactive: scp & rx do it */
		iflag = 1;
		break;
	case 'l':
	case 'u':
		user = EARGF(usage());
		break;
	case 'k':
		nopka = 1;
		break;
	case 'K':
		nopw = 1;
		break;
	case 'm':
		mflag = 1;
		break;
	case 'n':
		netdir = EARGF(usage());
		break;
	case 'r':
		stripcr = 1;
		break;
	case 'z':
		whichkey = EARGF(usage());
		break;
	default:
		usage();
	} ARGEND;
	if (argc == 0)
		usage();

	if (iflag == -1)
		iflag = isatty(0);
	remote = *argv++;
	--argc;

	parseargs();

	if (!user)
		user = getuser();
	if (user == nil || remote == nil)
		sysfatal("out of memory");

	starttunnel();

	/* fork subproc to handle keys; don't wait for it */
	n = rfork(RFPROC|RFMEM|RFFDG|RFNOWAIT);
	if (n < 0)
		sysfatal("fork failed: %r");	/* else chpid would be -1 */
	if (n == 0)
		keyproc(buf, sizeof buf);	/* never returns */
	/*
	 * NOTE(review): keyproc clears chpid when it finishes; if it
	 * finishes before this assignment, chpid names an exited process.
	 */
	chpid = n;
	atnotify(handler, 1);

	/* connect and learn connection number */
	conn = connect(buf, sizeof buf);

	consfd = open("/dev/cons", ORDWR);
	cctlfd = open("/dev/consctl", OWRITE);
	fprint(cctlfd, "rawon");
	if (doauth(cfd1, whichkey) < 0)
		bail("doauth");

	/* connect a channel of conn and learn channel number */
	chan = chanconnect(conn, buf, sizeof buf);

	/* open request channel, request shell or command execution */
	remotecmd(argc, argv, conn, chan, buf, sizeof buf);

	bidircopy(buf, sizeof buf);
}
574 
/*
 * report whether fd is a console: true when its path, as given by
 * fd2path, ends in "/dev/cons".
 */
int
isatty(int fd)
{
	char path[64];
	int len;

	path[0] = '\0';
	fd2path(fd, path, sizeof path);
	len = strlen(path);
	if (len < 9)		/* 9 == strlen("/dev/cons") */
		return 0;
	return strcmp(path + len - 9, "/dev/cons") == 0;
}
584 
/*
 * authenticate user on the connection whose ctl file is cfd1.
 * unless -k was given, first request public-key auth
 * ("ssh-userauth K", restricted to the -z attr=val selector whichkey
 * when one was given); if that ctl write succeeds we are done.
 * otherwise, unless nopw is set, fetch a password for remote/user
 * with auth_getuserpasswd (auth_getkey is passed only when iflag is
 * set) and send "ssh-userauth k".  the password is wiped from memory
 * and freed as soon as it has been sent.
 * returns 0 on success, -1 on failure.
 */
int
doauth(int cfd1, char *whichkey)
{
	UserPasswd *up;
	int n;
	char path[Arbpathlen], err[ERRMAX];

	if (!nopka) {
		if (whichkey)
			n = fprint(cfd1, "ssh-userauth K %q %q", user, whichkey);
		else
			n = fprint(cfd1, "ssh-userauth K %q", user);
		if (n >= 0)
			return 0;
	}
	if (nopw)
		return -1;
	up = auth_getuserpasswd(iflag? auth_getkey: nil,
		"proto=pass service=ssh server=%q user=%q", remote, user);
	if (up == nil) {
		fprint(2, "%s: didn't get password: %r\n", argv0);
		return -1;
	}
	n = fprint(cfd1, "ssh-userauth k %q %q", user, up->passwd);
	/* save that write's error now; fd2path below may overwrite it */
	rerrstr(err, sizeof err);
	/* don't leave the password lying around in memory */
	if (up->passwd != nil)
		memset(up->passwd, 0, strlen(up->passwd));
	free(up);
	if (n >= 0)
		return 0;

	path[0] = '\0';
	fd2path(cfd1, path, sizeof path);
	fprint(2, "%s: auth ctl msg `ssh-userauth k %q <password>' for %q: %s\n",
		argv0, user, path, err);
	return -1;
}
618