xref: /plan9/sys/src/cmd/ssh2/sshsession.c (revision 24d7e15e5e21820296ecc4fa071edfc938789ae0)
1 /*
2  * ssh server - serve SSH protocol v2
3  *	/net/ssh does most of the work; we copy bytes back and forth
4  */
5 #include <u.h>
6 #include <libc.h>
7 #include <ip.h>
8 #include <auth.h>
9 #include "ssh2.h"
10 
11 char *confine(char *, char *);
12 char *get_string(char *, char *);
13 void newchannel(int, char *, int);
14 void runcmd(int, int, char *, char *, char *, char *);
15 
16 int errfd, toppid, sflag, tflag, prevent;
17 int debug;
18 char *idstring;
19 char *netdir;				/* /net/ssh/<conn> */
20 char *nsfile = nil;
21 char *restdir;
22 char *shell;
23 char *srvpt;
24 char *uname;
25 
26 void
usage(void)27 usage(void)
28 {
29 	fprint(2, "usage: %s [-i id] [-s shell] [-r dir] [-R dir] [-S srvpt] "
30 		"[-n ns] [-t] [netdir]\n", argv0);
31 	exits("usage");
32 }
33 
34 static int
getctlfd(void)35 getctlfd(void)
36 {
37 	int ctlfd;
38 	char *name;
39 
40 	name = smprint("%s/clone", netdir);
41 	ctlfd = -1;
42 	if (name)
43 		ctlfd = open(name, ORDWR);
44 	if (ctlfd < 0) {
45 		syslog(0, "ssh", "server can't clone: %s: %r", name);
46 		exits("open clone");
47 	}
48 	free(name);
49 	return ctlfd;
50 }
51 
52 static int
getdatafd(int ctlfd)53 getdatafd(int ctlfd)
54 {
55 	int fd;
56 	char *name;
57 
58 	name = smprint("%s/data", netdir);
59 	fd = -1;
60 	if (name)
61 		fd = open(name, OREAD);
62 	if (fd < 0) {
63 		syslog(0, "ssh", "can't open %s: %r", name);
64 		hangup(ctlfd);
65 		exits("open data");
66 	}
67 	free(name);
68 	return fd;
69 }
70 
71 static void
auth(char * buf,int n,int ctlfd)72 auth(char *buf, int n, int ctlfd)
73 {
74 	int fd;
75 
76 	fd = open("#¤/capuse", OWRITE);
77 	if (fd < 0) {
78 		syslog(0, "ssh", "server can't open capuse: %r");
79 		hangup(ctlfd);
80 		exits("capuse");
81 	}
82 	if (write(fd, buf, n) != n) {
83 		syslog(0, "ssh", "server write `%.*s' to capuse failed: %r",
84 			n, buf);
85 		hangup(ctlfd);
86 		exits("capuse");
87 	}
88 	close(fd);
89 }
90 
91 /*
92  * mount tunnel if there isn't one visible.
93  */
94 static void
mounttunnel(int ctlfd)95 mounttunnel(int ctlfd)
96 {
97 	int fd;
98 	char *p, *np, *q;
99 
100 	if (access(netdir, AEXIST) >= 0)
101 		return;
102 
103 	p = smprint("/srv/%s", srvpt? srvpt: "ssh");
104 	np = strdup(netdir);
105 	if (p == nil || np == nil)
106 		sysfatal("out of memory");
107 	q = strstr(np, "/ssh");
108 	if (q != nil)
109 		*q = '\0';
110 	fd = open(p, ORDWR);
111 	if (fd < 0) {
112 		syslog(0, "ssh", "can't open %s: %r", p);
113 		hangup(ctlfd);
114 		exits("open");
115 	}
116 	if (mount(fd, -1, np, MBEFORE, "") < 0) {
117 		syslog(0, "ssh", "can't mount %s in %s: %r", p, np);
118 		hangup(ctlfd);
119 		exits("can't mount");
120 	}
121 	free(p);
122 	free(np);
123 }
124 
125 static int
authnewns(int ctlfd,char * buf,int size,int n)126 authnewns(int ctlfd, char *buf, int size, int n)
127 {
128 	char *p, *q;
129 
130 	USED(size);
131 	if (n <= 0)
132 		return 0;
133 	buf[n] = '\0';
134 	if (strcmp(buf, "n/a") == 0)
135 		return 0;
136 
137 	auth(buf, n, ctlfd);
138 
139 	p = strchr(buf, '@');
140 	if (p == nil)
141 		return 0;
142 	++p;
143 	q = strchr(p, '@');
144 	if (q) {
145 		*q = '\0';
146 		uname = strdup(p);
147 	}
148 	if (!tflag && newns(p, nsfile) < 0) {
149 		syslog(0, "ssh", "server: newns(%s,%s) failed: %r", p, nsfile);
150 		return -1;
151 	}
152 	return 0;
153 }
154 
155 static void
listenloop(char * listfile,int ctlfd,char * buf,int size)156 listenloop(char *listfile, int ctlfd, char *buf, int size)
157 {
158 	int fd, n;
159 
160 	while ((fd = open(listfile, ORDWR)) >= 0) {
161 		n = read(fd, buf, size - 1);
162 		fprint(errfd, "read from listen file returned %d\n", n);
163 		if (n <= 0) {
164 			syslog(0, "ssh", "read on listen failed: %r");
165 			break;
166 		}
167 		buf[n >= 0? n: 0] = '\0';
168 		fprint(errfd, "read %s\n", buf);
169 
170 		switch (fork()) {
171 		case 0:					/* child */
172 			close(ctlfd);
173 			newchannel(fd, netdir, atoi(buf));  /* never returns */
174 		case -1:
175 			syslog(0, "ssh", "fork failed: %r");
176 			hangup(ctlfd);
177 			exits("fork");
178 		}
179 		close(fd);
180 	}
181 	if (fd < 0)
182 		syslog(0, "ssh", "listen failed: %r");
183 }
184 
185 void
main(int argc,char * argv[])186 main(int argc, char *argv[])
187 {
188 	char *listfile;
189 	int ctlfd, fd, n;
190 	char buf[Arbpathlen], path[Arbpathlen];
191 
192 	rfork(RFNOTEG);
193 	toppid = getpid();
194 	shell = "/bin/rc -il";
195 	ARGBEGIN {
196 	case 'd':
197 		debug++;
198 		break;
199 	case 'i':
200 		idstring = EARGF(usage());
201 		break;
202 	case 'n':
203 		nsfile = EARGF(usage());
204 		break;
205 	case 'R':
206 		prevent = 1;
207 		/* fall through */
208 	case 'r':
209 		restdir = EARGF(usage());
210 		break;
211 	case 's':
212 		sflag = 1;
213 		shell = EARGF(usage());
214 		break;
215 	case 'S':
216 		srvpt = EARGF(usage());
217 		break;
218 	case 't':
219 		tflag = 1;
220 		break;
221 	default:
222 		usage();
223 		break;
224 	} ARGEND;
225 
226 	errfd = -1;
227 	if (debug)
228 		errfd = 2;
229 
230 	/* work out network connection's directory */
231 	if (argc >= 1)
232 		netdir = argv[0];
233 	else				/* invoked by listen1 */
234 		netdir = getenv("net");
235 	if (netdir == nil) {
236 		syslog(0, "ssh", "server netdir is nil");
237 		exits("nil netdir");
238 	}
239 	syslog(0, "ssh", "server netdir is %s", netdir);
240 
241 	uname = getenv("user");
242 	if (uname == nil)
243 		uname = "none";
244 
245 	/* extract dfd and cfd from netdir */
246 	ctlfd = getctlfd();
247 	fd = getdatafd(ctlfd);
248 
249 	n = read(fd, buf, sizeof buf - 1);
250 	if (n < 0) {
251 		syslog(0, "ssh", "server read error for data file: %r");
252 		hangup(ctlfd);
253 		exits("read cap");
254 	}
255 	close(fd);
256 	authnewns(ctlfd, buf, sizeof buf, n);
257 
258 	/* get connection number in buf */
259 	n = read(ctlfd, buf, sizeof buf - 1);
260 	buf[n >= 0? n: 0] = '\0';
261 
262 	/* tell netssh our id string */
263 	fd2path(ctlfd, path, sizeof path);
264 	if (0 && idstring) {			/* was for coexistence */
265 		syslog(0, "ssh", "server conn %s, writing \"id %s\" to %s",
266 			buf, idstring, path);
267 		fprint(ctlfd, "id %s", idstring);
268 	}
269 
270 	/* announce */
271 	fprint(ctlfd, "announce session");
272 
273 	/* construct listen file name */
274 	listfile = smprint("%s/%s/listen", netdir, buf);
275 	if (listfile == nil) {
276 		syslog(0, "ssh", "out of memory");
277 		exits("out of memory");
278 	}
279 	syslog(0, "ssh", "server listen is %s", listfile);
280 
281 	mounttunnel(ctlfd);
282 	listenloop(listfile, ctlfd, buf, sizeof buf);
283 	hangup(ctlfd);
284 	exits(nil);
285 }
286 
/*
 * an abbreviation.  note the assumed variables: want_reply and reqfd.
 * when want_reply is set, write the reply string s to reqfd.
 * wrapped in do/while so that "REPLY(x);" is a single statement
 * (safe before an else), and s is printed with "%s" rather than
 * being used as the format.
 */
#define REPLY(s)	do { if (want_reply) fprint(reqfd, "%s", (s)); } while (0)
289 
290 static void
forkshell(char * cmd,int reqfd,int datafd,int want_reply)291 forkshell(char *cmd, int reqfd, int datafd, int want_reply)
292 {
293 	switch (fork()) {
294 	case 0:
295 		if (sflag)
296 			snprint(cmd, sizeof cmd, "-s%s", shell);
297 		else
298 			cmd[0] = '\0';
299 		USED(cmd);
300 		syslog(0, "ssh", "server starting ssh shell for %s", uname);
301 		/* runcmd doesn't return */
302 		runcmd(reqfd, datafd, "con", "/bin/ip/telnetd", "-nt", nil);
303 	case -1:
304 		REPLY("failure");
305 		syslog(0, "ssh", "server can't fork: %r");
306 		exits("fork");
307 	}
308 }
309 
310 static void
forkcmd(char * cmd,char * p,int reqfd,int datafd,int want_reply)311 forkcmd(char *cmd, char *p, int reqfd, int datafd, int want_reply)
312 {
313 	char *q;
314 
315 	switch (fork()) {
316 	case 0:
317 		if (restdir && chdir(restdir) < 0) {
318 			syslog(0, "ssh", "can't chdir(%s): %r", restdir);
319 			exits("can't chdir");
320 		}
321 		if (!prevent || (q = getenv("sshsession")) &&
322 		    strcmp(q, "allow") == 0)
323 			get_string(p+1, cmd);
324 		else
325 			confine(p+1, cmd);
326 		syslog(0, "ssh", "server running `%s' for %s", cmd, uname);
327 		/* runcmd doesn't return */
328 		runcmd(reqfd, datafd, "rx", "/bin/rc", "-lc", cmd);
329 	case -1:
330 		REPLY("failure");
331 		syslog(0, "ssh", "server can't fork: %r");
332 		exits("fork");
333 	}
334 }
335 
336 void
newchannel(int fd,char * conndir,int channum)337 newchannel(int fd, char *conndir, int channum)
338 {
339 	char *p, *reqfile, *datafile;
340 	int n, reqfd, datafd, want_reply, already_done;
341 	char buf[Maxpayload], cmd[Bigbufsz];
342 
343 	close(fd);
344 
345 	already_done = 0;
346 	reqfile = smprint("%s/%d/request", conndir, channum);
347 	if (reqfile == nil)
348 		sysfatal("out of memory");
349 	reqfd = open(reqfile, ORDWR);
350 	if (reqfd < 0) {
351 		syslog(0, "ssh", "can't open request file %s: %r", reqfile);
352 		exits("net");
353 	}
354 	datafile = smprint("%s/%d/data", conndir, channum);
355 	if (datafile == nil)
356 		sysfatal("out of memory");
357 	datafd = open(datafile, ORDWR);
358 	if (datafd < 0) {
359 		syslog(0, "ssh", "can't open data file %s: %r", datafile);
360 		exits("net");
361 	}
362 	while ((n = read(reqfd, buf, sizeof buf - 1)) > 0) {
363 		fprint(errfd, "read from request file returned %d\n", n);
364 		for (p = buf; p < buf + n && *p != ' '; ++p)
365 			;
366 		*p++ = '\0';
367 		want_reply = (*p == 't');
368 		/* shell, exec, and various flavours of failure */
369 		if (strcmp(buf, "shell") == 0) {
370 			if (already_done) {
371 				REPLY("failure");
372 				continue;
373 			}
374 			forkshell(cmd, reqfd, datafd, want_reply);
375 			already_done = 1;
376 			REPLY("success");
377 		} else if (strcmp(buf, "exec") == 0) {
378 			if (already_done) {
379 				REPLY("failure");
380 				continue;
381 			}
382 			forkcmd(cmd, p, reqfd, datafd, want_reply);
383 			already_done = 1;
384 			REPLY("success");
385 		} else if (strcmp(buf, "pty-req") == 0 ||
386 		    strcmp(buf, "window-change") == 0) {
387 			REPLY("success");
388 		} else if (strcmp(buf, "x11-req") == 0 ||
389 		    strcmp(buf, "env") == 0 || strcmp(buf, "subsystem") == 0) {
390 			REPLY("failure");
391 		} else if (strcmp(buf, "xon-xoff") == 0 ||
392 		    strcmp(buf, "signal") == 0 ||
393 		    strcmp(buf, "exit-status") == 0 ||
394 		    strcmp(buf, "exit-signal") == 0) {
395 			;
396 		} else
397 			syslog(0, "ssh", "server unknown channel request: %s",
398 				buf);
399 	}
400 	if (n < 0)
401 		syslog(0, "ssh", "server read failed: %r");
402 	exits(nil);
403 }
404 
/*
 * copy a length-prefixed string into s and NUL-terminate it.  q points
 * at a 4-byte length (read with nhgetl) followed by that many bytes.
 * nothing is bounds-checked here: q must contain the whole string and
 * s must hold length+1 bytes.  returns a pointer just past the string.
 */
char *
get_string(char *q, char *s)
{
	char *body;
	int len;

	len = nhgetl(q);
	body = q + 4;
	memmove(s, body, len);
	s[len] = '\0';
	return body + len;
}
417 
418 char *
confine(char * q,char * s)419 confine(char *q, char *s)
420 {
421 	int i, n, m;
422 	char *p, *e, *r, *buf, *toks[Maxtoks];
423 
424 	n = nhgetl(q);
425 	q += 4;
426 	buf = malloc(n+1);
427 	if (buf == nil)
428 		return nil;
429 	memmove(buf, q, n);
430 	buf[n]  = 0;
431 	m = tokenize(buf, toks, nelem(toks));
432 	e = s + n + 1;
433 	for (i = 0, r = s; i < m; ++i) {
434 		p = strrchr(toks[i], '/');
435 		if (p == nil)
436 			r = seprint(r, e, "%s ", toks[i]);
437 		else if (p[0] != '\0' && p[1] != '\0')
438 			r = seprint(r, e, "%s ", p+1);
439 		else
440 			r = seprint(r, e, ". ");
441 	}
442 	free(buf);
443 	q += n;
444 	return q;
445 }
446 
447 void
runcmd(int reqfd,int datafd,char * svc,char * cmd,char * arg1,char * arg2)448 runcmd(int reqfd, int datafd, char *svc, char *cmd, char *arg1, char *arg2)
449 {
450 	char *p;
451 	int fd, cmdpid, child;
452 
453 	cmdpid = rfork(RFPROC|RFMEM|RFNOTEG|RFFDG|RFENVG);
454 	switch (cmdpid) {
455 	case -1:
456 		syslog(0, "ssh", "fork failed: %r");
457 		exits("fork");
458 	case 0:
459 		if (restdir == nil) {
460 			p = smprint("/usr/%s", uname);
461 			if (p && access(p, AREAD) == 0 && chdir(p) < 0) {
462 				syslog(0, "ssh", "can't chdir(%s): %r", p);
463 				exits("can't chdir");
464 			}
465 			free(p);
466 		}
467 		p = strrchr(cmd, '/');
468 		if (p)
469 			++p;
470 		else
471 			p = cmd;
472 
473 		dup(datafd, 0);
474 		dup(datafd, 1);
475 		dup(datafd, 2);
476 		close(datafd);
477 		putenv("service", svc);
478 		fprint(errfd, "starting %s\n", cmd);
479 		execl(cmd, p, arg1, arg2, nil);
480 
481 		syslog(0, "ssh", "cannot exec %s: %r", cmd);
482 		exits("exec");
483 	default:
484 		close(datafd);
485 		fprint(errfd, "waiting for child %d\n", cmdpid);
486 		while ((child = waitpid()) != cmdpid && child != -1)
487 			fprint(errfd, "child %d passed\n", child);
488 		if (child == -1)
489 			fprint(errfd, "wait failed: %r\n");
490 
491 		syslog(0, "ssh", "server closing ssh session for %s", uname);
492 		fprint(errfd, "closing connection\n");
493 		fprint(reqfd, "close");
494 		p = smprint("/proc/%d/notepg", toppid);
495 		if (p) {
496 			fd = open(p, OWRITE);
497 			fprint(fd, "interrupt");
498 			close(fd);
499 		}
500 		exits(nil);
501 	}
502 }
503