xref: /plan9/sys/src/cmd/ssh2/sshsession.c (revision 401314a3b4602c168a19b28ed47ba5cbefe42fe0)
1 /*
2  * ssh server - serve SSH protocol v2
3  *	/net/ssh does most of the work; we copy bytes back and forth
4  */
5 #include <u.h>
6 #include <libc.h>
7 #include <ip.h>
8 #include <auth.h>
9 #include "ssh2.h"
10 
/* functions defined later in this file */
char *confine(char *q, char *s);
char *get_string(char *q, char *s);
void newchannel(int fd, char *conndir, int channum);
void runcmd(int reqfd, int datafd, char *svc, char *cmd, char *arg1,
	char *arg2);

int errfd;		/* debug output fd: 2 with -d, otherwise -1 */
int toppid;		/* pid of main; runcmd interrupts its note group */
int sflag;		/* -s given */
int tflag;		/* -t given: authnewns skips newns */
int prevent;		/* -R given: exec'd commands go through confine */
int debug;		/* number of -d flags */
char *idstring;		/* -i argument */
char *netdir = "/net";	/* connection directory; main replaces this */
char *nsfile = nil;	/* -n namespace file handed to newns */
char *restdir;		/* -r/-R directory for exec'd commands */
char *shell;		/* shell string; main defaults it, -s overrides */
char *srvpt;		/* -S name under /srv used by mounttunnel */
char *uname;		/* user name used in logs and for /usr/<uname> */
25 
/*
 * print a usage message naming every flag main accepts (including -d,
 * which the old message left out) and exit.
 */
void
usage(void)
{
	fprint(2, "usage: %s [-d] [-i id] [-s shell] [-r dir] [-R dir] "
		"[-S srvpt] [-n ns] [-t] [netdir]\n", argv0);
	exits("usage");
}
33 
/*
 * open <netdir>/clone read-write and return the descriptor.
 * if the path can't be built or opened, log it and exit.
 */
static int
getctlfd(void)
{
	char *clonefile;
	int fd;

	clonefile = smprint("%s/clone", netdir);
	fd = clonefile == nil? -1: open(clonefile, ORDWR);
	if (fd >= 0) {
		free(clonefile);
		return fd;
	}
	syslog(0, "ssh", "server can't clone: %s: %r", clonefile);
	exits("open clone");
	return -1;	/* not reached */
}
51 
/*
 * open <netdir>/data for reading and return the descriptor.
 * on failure, log it, hang up the connection on ctlfd and exit.
 */
static int
getdatafd(int ctlfd)
{
	char *datafile;
	int fd;

	datafile = smprint("%s/data", netdir);
	fd = datafile == nil? -1: open(datafile, OREAD);
	if (fd >= 0) {
		free(datafile);
		return fd;
	}
	syslog(0, "ssh", "can't open %s: %r", datafile);
	hangup(ctlfd);
	exits("open data");
	return -1;	/* not reached */
}
70 
/*
 * write the n-byte capability in buf to #¤/capuse.
 * both an open failure and a short write are logged and fatal:
 * the connection on ctlfd is hung up and the process exits.
 */
static void
auth(char *buf, int n, int ctlfd)
{
	int capfd;

	capfd = open("#¤/capuse", OWRITE);
	if (capfd < 0) {
		syslog(0, "ssh", "server can't open capuse: %r");
		goto fail;
	}
	if (write(capfd, buf, n) != n) {
		syslog(0, "ssh", "server write `%.*s' to capuse failed: %r",
			n, buf);
		goto fail;
	}
	close(capfd);
	return;

fail:
	hangup(ctlfd);
	exits("capuse");
}
90 
/*
 * start tunnel if there isn't one, though it's probably too late
 * since caphash will likely be closed.
 *
 * if netdir does not exist, open /srv/<srvpt> (/srv/ssh without -S)
 * and mount it before netdir.  failures are logged and fatal: the
 * connection on ctlfd is hung up and the process exits.
 */
static void
mounttunnel(int ctlfd)
{
	int fd;
	char *p;

	if (access(netdir, AEXIST) >= 0)
		return;

	p = smprint("/srv/%s", srvpt? srvpt: "ssh");
	fd = -1;
	if (p)
		fd = open(p, ORDWR);
	if (fd < 0) {
		syslog(0, "ssh", "can't open %s: %r", p);
		hangup(ctlfd);
		exits("open");
	}
	free(p);
	if (mount(fd, -1, netdir, MBEFORE, "") < 0) {
		syslog(0, "ssh", "can't mount in %s: %r", netdir);
		hangup(ctlfd);
		exits("can't mount");
	}
	/* per bind(2) the descriptor may be closed once mounted */
	close(fd);
}
119 
/*
 * switch to the user named in the client's capability.
 *
 * buf holds the n bytes read from the data file; size is buf's
 * capacity.  n <= 0, or a capability of "n/a", means there is nothing
 * to do.  otherwise the capability is written to capuse by auth()
 * (fatal on failure).  if it contains '@', the text after the first
 * '@' names the user.  when a second '@' follows, the name stops
 * there and is also saved in uname.  unless -t was given, newns then
 * builds that user's namespace from nsfile.
 *
 * returns 0 on success or when there is nothing to do, and -1 if the
 * capability leaves no room for a NUL in buf or newns fails.
 */
static int
authnewns(int ctlfd, char *buf, int size, int n)
{
	char *p, *q;

	if (n <= 0)
		return 0;
	if (n >= size) {
		/* buf[n] would be out of bounds */
		syslog(0, "ssh", "server: capability too long (%d bytes)", n);
		return -1;
	}
	buf[n] = '\0';
	if (strcmp(buf, "n/a") == 0)
		return 0;

	auth(buf, n, ctlfd);

	p = strchr(buf, '@');
	if (p == nil)
		return 0;
	++p;
	q = strchr(p, '@');
	if (q) {
		*q = '\0';
		uname = strdup(p);
	}
	if (!tflag && newns(p, nsfile) < 0) {
		syslog(0, "ssh", "server: newns(%s,%s) failed: %r", p, nsfile);
		return -1;
	}
	return 0;
}
149 
/*
 * serve the connection's channels.  repeatedly open listfile and read
 * a channel number from it into buf (size bytes).  for each one, fork
 * a child that runs newchannel on it (never returns); the parent
 * closes its copy of the descriptor and loops.  returns once listfile
 * can't be opened or a read from it fails.
 */
static void
listenloop(char *listfile, int ctlfd, char *buf, int size)
{
	int fd, n;

	while ((fd = open(listfile, ORDWR)) >= 0) {
		n = read(fd, buf, size - 1);
		fprint(errfd, "read from listen file returned %d\n", n);
		if (n <= 0) {
			syslog(0, "ssh", "read on listen failed: %r");
			close(fd);		/* don't leak it on the way out */
			return;
		}
		buf[n] = '\0';
		fprint(errfd, "read %s\n", buf);

		switch (fork()) {
		case 0:					/* child */
			close(ctlfd);
			newchannel(fd, netdir, atoi(buf));  /* never returns */
		case -1:
			syslog(0, "ssh", "fork failed: %r");
			hangup(ctlfd);
			exits("fork");
		}
		close(fd);
	}
	syslog(0, "ssh", "listen failed: %r");
}
179 
/*
 * sshsession entry point.
 *
 * parse flags and find the connection directory (the argument, or $net
 * when run by listen1).  open <netdir>/clone and <netdir>/data, and
 * apply the capability read from the data file via authnewns.  then
 * write "announce session" to the clone's ctl descriptor and serve
 * channels from <netdir>/<conn>/listen until that fails.
 */
void
main(int argc, char *argv[])
{
	char *listfile;
	int ctlfd, fd, n;
	char buf[Arbpathlen], path[Arbpathlen];

	rfork(RFNOTEG);		/* runcmd later interrupts this note group */
	toppid = getpid();
	shell = "/bin/rc -il";
	ARGBEGIN {
	case 'd':
		debug++;
		break;
	case 'i':
		idstring = EARGF(usage());
		break;
	case 'n':
		nsfile = EARGF(usage());
		break;
	case 'R':
		prevent = 1;
		/* fall through */
	case 'r':
		restdir = EARGF(usage());
		break;
	case 's':
		sflag = 1;
		shell = EARGF(usage());
		break;
	case 'S':
		srvpt = EARGF(usage());
		break;
	case 't':
		tflag = 1;
		break;
	default:
		usage();
		break;
	} ARGEND;

	errfd = -1;
	if (debug)
		errfd = 2;

	/* work out network connection's directory */
	if (argc >= 1)
		netdir = argv[0];
	else				/* invoked by listen1 */
		netdir = getenv("net");
	if (netdir == nil) {
		syslog(0, "ssh", "server netdir is nil");
		exits("nil netdir");
	}
	syslog(0, "ssh", "server netdir is %s", netdir);

	uname = getenv("user");
	if (uname == nil)
		uname = "none";

	/* extract dfd and cfd from netdir */
	ctlfd = getctlfd();
	fd = getdatafd(ctlfd);

	n = read(fd, buf, sizeof buf - 1);
	if (n < 0) {
		syslog(0, "ssh", "server read error for data file: %r");
		hangup(ctlfd);
		exits("read cap");
	}
	close(fd);
	/*
	 * fail closed: never keep serving in the listener's namespace
	 * when the user's namespace couldn't be built.
	 */
	if (authnewns(ctlfd, buf, sizeof buf, n) < 0) {
		syslog(0, "ssh", "server can't set up namespace for %s", uname);
		hangup(ctlfd);
		exits("newns");
	}

	/* get connection number in buf */
	n = read(ctlfd, buf, sizeof buf - 1);
	if (n <= 0) {
		syslog(0, "ssh", "server can't read connection number: %r");
		hangup(ctlfd);
		exits("read ctl");
	}
	buf[n] = '\0';

	/* tell netssh our id string */
	fd2path(ctlfd, path, sizeof path);
	if (0 && idstring) {			/* was for coexistence */
		syslog(0, "ssh", "server conn %s, writing \"id %s\" to %s",
			buf, idstring, path);
		fprint(ctlfd, "id %s", idstring);
	}

	/* announce */
	fprint(ctlfd, "announce session");

	/* construct listen file name */
	listfile = smprint("%s/%s/listen", netdir, buf);
	if (listfile == nil) {
		syslog(0, "ssh", "out of memory");
		exits("out of memory");
	}
	syslog(0, "ssh", "server listen is %s", listfile);

	mounttunnel(ctlfd);
	listenloop(listfile, ctlfd, buf, sizeof buf);
	hangup(ctlfd);
	exits(nil);
}
281 
/*
 * an abbreviation: send reply s on reqfd if the peer asked for one.
 * note the assumed variables: want_reply and reqfd must be in scope.
 * do/while(0) makes it a single statement (safe before an else), and
 * s is printed with "%s" rather than used as the format.
 */
#define REPLY(s)	do { if (want_reply) fprint(reqfd, "%s", (s)); } while (0)
284 
/*
 * start an interactive session.  fork a child that uses runcmd to run
 * /bin/ip/telnetd -nt (service "con") on the channel's request and data
 * descriptors; the parent returns at once.  if the fork fails, send
 * "failure" (when a reply was wanted) and exit.
 * cmd is scratch space; the only caller passes a Bigbufsz-byte buffer.
 */
static void
forkshell(char *cmd, int reqfd, int datafd, int want_reply)
{
	switch (fork()) {
	case 0:
		/*
		 * cmd is a pointer here, so sizeof cmd would be the size of
		 * a pointer rather than of the caller's buffer; bound the
		 * copy by Bigbufsz instead.
		 * NOTE(review): this -s string is never handed to telnetd,
		 * so -s has no effect on shell sessions; confirm whether
		 * telnetd should receive it.
		 */
		if (sflag)
			snprint(cmd, Bigbufsz, "-s%s", shell);
		else
			cmd[0] = '\0';
		USED(cmd);
		syslog(0, "ssh", "server starting ssh shell for %s", uname);
		/* runcmd doesn't return */
		runcmd(reqfd, datafd, "con", "/bin/ip/telnetd", "-nt", nil);
	case -1:
		REPLY("failure");
		syslog(0, "ssh", "server can't fork: %r");
		exits("fork");
	}
}
304 
/*
 * run a client's exec request.  fork a child that:
 *   - moves to restdir, if one is set;
 *   - copies the command string starting at p+1 into cmd;
 *   - runs it with /bin/rc -lc (service "rx") via runcmd.
 * with -R, the command goes through confine unless $sshsession is
 * "allow".  the parent returns at once.  if the fork fails, send
 * "failure" (when a reply was wanted) and exit.
 * p points at the request's want-reply byte; cmd is the caller's
 * Bigbufsz-byte buffer.
 */
static void
forkcmd(char *cmd, char *p, int reqfd, int datafd, int want_reply)
{
	char *q;

	switch (fork()) {
	case 0:
		if (restdir && chdir(restdir) < 0) {
			syslog(0, "ssh", "can't chdir(%s): %r", restdir);
			exits("can't chdir");
		}
		if (!prevent || ((q = getenv("sshsession")) != nil &&
		    strcmp(q, "allow") == 0)) {
			get_string(p+1, cmd);
		} else if (confine(p+1, cmd) == nil) {
			/* never run whatever happens to be left in cmd */
			syslog(0, "ssh", "server can't confine command for %s",
				uname);
			exits("confine");
		}
		syslog(0, "ssh", "server running `%s' for %s", cmd, uname);
		/* runcmd doesn't return */
		runcmd(reqfd, datafd, "rx", "/bin/rc", "-lc", cmd);
	case -1:
		REPLY("failure");
		syslog(0, "ssh", "server can't fork: %r");
		exits("fork");
	}
}
330 
/*
 * serve channel channum of the connection in conndir; never returns.
 *
 * fd (the listen descriptor) is closed first.  then open the channel's
 * request and data files and handle requests until the request file
 * reads EOF or fails:
 *   - the first "shell" or "exec" starts a child; later ones are refused;
 *   - pty-req and window-change get "success";
 *   - x11-req, env and subsystem get "failure";
 *   - xon-xoff, signal, exit-status and exit-signal are ignored;
 *   - anything else is logged.
 */
void
newchannel(int fd, char *conndir, int channum)
{
	char *p, *e, *reqfile, *datafile;
	int n, reqfd, datafd, want_reply, already_done;
	ulong len;
	char buf[Maxpayload], cmd[Bigbufsz];

	close(fd);

	already_done = 0;
	reqfile = smprint("%s/%d/request", conndir, channum);
	if (reqfile == nil)
		sysfatal("out of memory");
	reqfd = open(reqfile, ORDWR);
	if (reqfd < 0) {
		syslog(0, "ssh", "can't open request file %s: %r", reqfile);
		exits("net");
	}
	datafile = smprint("%s/%d/data", conndir, channum);
	if (datafile == nil)
		sysfatal("out of memory");
	datafd = open(datafile, ORDWR);
	if (datafd < 0) {
		syslog(0, "ssh", "can't open data file %s: %r", datafile);
		exits("net");
	}
	while ((n = read(reqfd, buf, sizeof buf - 1)) > 0) {
		fprint(errfd, "read from request file returned %d\n", n);
		/*
		 * a request is "name want-reply-byte payload".  terminate
		 * what was read and split off the name without ever looking
		 * past the bytes actually received.
		 */
		e = buf + n;
		*e = '\0';
		for (p = buf; p < e && *p != ' '; ++p)
			;
		want_reply = 0;
		if (p < e) {
			*p++ = '\0';
			want_reply = (*p == 't');
		}
		/* shell, exec, and various flavours of failure */
		if (strcmp(buf, "shell") == 0) {
			if (already_done) {
				REPLY("failure");
				continue;
			}
			forkshell(cmd, reqfd, datafd, want_reply);
			already_done = 1;
			REPLY("success");
		} else if (strcmp(buf, "exec") == 0) {
			if (already_done) {
				REPLY("failure");
				continue;
			}
			/*
			 * after the want-reply byte come a 4-byte length and
			 * the command.  reject the request if it runs past
			 * what was read or can't fit (with its NUL) in cmd.
			 */
			if (e - p < 5 || (len = nhgetl(p+1)) >= Bigbufsz ||
			    len > e - (p + 5)) {
				syslog(0, "ssh", "server malformed exec request");
				REPLY("failure");
				continue;
			}
			forkcmd(cmd, p, reqfd, datafd, want_reply);
			already_done = 1;
			REPLY("success");
		} else if (strcmp(buf, "pty-req") == 0 ||
		    strcmp(buf, "window-change") == 0) {
			REPLY("success");
		} else if (strcmp(buf, "x11-req") == 0 ||
		    strcmp(buf, "env") == 0 || strcmp(buf, "subsystem") == 0) {
			REPLY("failure");
		} else if (strcmp(buf, "xon-xoff") == 0 ||
		    strcmp(buf, "signal") == 0 ||
		    strcmp(buf, "exit-status") == 0 ||
		    strcmp(buf, "exit-signal") == 0) {
			;
		} else
			syslog(0, "ssh", "server unknown channel request: %s",
				buf);
	}
	if (n < 0)
		syslog(0, "ssh", "server read failed: %r");
	exits(nil);
}
398 }
399 
/*
 * copy a length-prefixed string (4-byte network-order length read with
 * nhgetl, then that many bytes) from q into s, NUL-terminated.
 * returns a pointer just past the string in q.
 *
 * s is assumed to hold Bigbufsz bytes (the only caller passes such a
 * buffer).  a length that won't fit, or that is negative as an int,
 * leaves s empty: nothing overflows and no truncated command remains.
 */
char *
get_string(char *q, char *s)
{
	int n;

	n = nhgetl(q);
	q += 4;
	if (n < 0 || n >= Bigbufsz) {
		s[0] = '\0';
		return n < 0? q: q + n;
	}
	memmove(s, q, n);
	s[n] = '\0';
	return q + n;
}
412 
/*
 * like get_string, but rewrite the command for -R mode.
 *
 * the length-prefixed string at q (4-byte network-order length, then
 * bytes) is split with tokenize into at most Maxtoks tokens.  each
 * token is written to s followed by a space, reduced to its last path
 * element:
 *   - a token without '/' is kept as is;
 *   - "a/b" becomes "b";
 *   - a token ending in '/' becomes ".".
 *
 * s is assumed to hold Bigbufsz bytes and is always NUL-terminated;
 * it is left empty on failure.  returns a pointer just past the string
 * in q, or nil if the string can't fit in s or memory runs out.
 */
char *
confine(char *q, char *s)
{
	int i, n, m;
	char *p, *e, *r, *buf, *toks[Maxtoks];

	s[0] = '\0';		/* never leave s unset, even with no tokens */
	n = nhgetl(q);
	q += 4;
	if (n < 0 || n >= Bigbufsz)
		return nil;	/* output is bounded by s + n + 1 below */
	buf = malloc(n+1);
	if (buf == nil)
		return nil;
	memmove(buf, q, n);
	buf[n] = '\0';
	m = tokenize(buf, toks, nelem(toks));
	e = s + n + 1;
	r = s;
	for (i = 0; i < m; ++i) {
		p = strrchr(toks[i], '/');	/* *p is '/' when non-nil */
		if (p == nil)
			r = seprint(r, e, "%s ", toks[i]);
		else if (p[1] != '\0')
			r = seprint(r, e, "%s ", p+1);
		else
			r = seprint(r, e, ". ");
	}
	free(buf);
	return q + n;
}
441 
/*
 * run cmd for a channel, then tear the session down; never returns.
 *
 * rfork(2) a child with RFMEM|RFNOTEG|RFFDG|RFENVG.  the child:
 *   - moves to /usr/<uname> when restdir is unset and that directory
 *     is readable;
 *   - makes datafd its stdin, stdout and stderr;
 *   - sets $service to svc;
 *   - execs cmd with argv[0] = cmd's last path element, followed by
 *     arg1 and arg2 (arg2 may be nil).
 * the parent:
 *   - closes datafd and waits for that child;
 *   - writes "close" to reqfd;
 *   - writes "interrupt" to toppid's notepg, then exits.
 */
void
runcmd(int reqfd, int datafd, char *svc, char *cmd, char *arg1, char *arg2)
{
	char *p;
	int fd, cmdpid, child;

	cmdpid = rfork(RFPROC|RFMEM|RFNOTEG|RFFDG|RFENVG);
	switch (cmdpid) {
	case -1:
		syslog(0, "ssh", "fork failed: %r");
		exits("fork");
	case 0:
		if (restdir == nil) {
			p = smprint("/usr/%s", uname);
			if (p && access(p, AREAD) == 0 && chdir(p) < 0) {
				syslog(0, "ssh", "can't chdir(%s): %r", p);
				exits("can't chdir");
			}
			free(p);
		}
		p = strrchr(cmd, '/');
		if (p)
			++p;
		else
			p = cmd;

		dup(datafd, 0);
		dup(datafd, 1);
		dup(datafd, 2);
		close(datafd);
		putenv("service", svc);
		fprint(errfd, "starting %s\n", cmd);
		execl(cmd, p, arg1, arg2, nil);

		syslog(0, "ssh", "cannot exec %s: %r", cmd);
		exits("exec");
	default:
		close(datafd);
		fprint(errfd, "waiting for child %d\n", cmdpid);
		while ((child = waitpid()) != cmdpid && child != -1)
			fprint(errfd, "child %d passed\n", child);
		if (child == -1)
			fprint(errfd, "wait failed: %r\n");

		syslog(0, "ssh", "server closing ssh session for %s", uname);
		fprint(errfd, "closing connection\n");
		fprint(reqfd, "close");
		/* interrupt every process still in toppid's note group */
		p = smprint("/proc/%d/notepg", toppid);
		if (p) {
			fd = open(p, OWRITE);
			if (fd < 0)
				syslog(0, "ssh", "can't open %s: %r", p);
			else {
				fprint(fd, "interrupt");
				close(fd);
			}
			free(p);
		}
		exits(nil);
	}
}
498