xref: /plan9-contrib/sys/src/cmd/import.c (revision 4d44ba9b9ee4246ddbd96c7fcaf0918ab92ab35a)
1 #include <u.h>
2 #include <libc.h>
3 #include <auth.h>
4 #include <libsec.h>
5 
6 enum {
7 	Encnone,
8 	Encssl,
9 	Enctls,
10 };
11 
12 static char *encprotos[] = {
13 	[Encnone] =	"clear",
14 	[Encssl] =		"ssl",
15 	[Enctls] = 		"tls",
16 				nil,
17 };
18 
19 char		*keyspec = "";
20 char		*filterp;
21 char		*ealgs = "rc4_256 sha1";
22 int		encproto = Encnone;
23 char		*aan = "/bin/aan";
24 AuthInfo 	*ai;
25 int		debug;
26 
27 int	connect(char*, char*, int);
28 int	passive(void);
29 int	old9p(int);
30 void	catcher(void*, char*);
31 void	sysfatal(char*, ...);
32 void	usage(void);
33 int	filter(int, char *, char *);
34 
35 static void	mksecret(char *, uchar *);
36 
37 void
38 post(char *name, char *envname, int srvfd)
39 {
40 	int fd;
41 	char buf[32];
42 
43 	fd = create(name, OWRITE, 0600);
44 	if(fd < 0)
45 		return;
46 	sprint(buf, "%d",srvfd);
47 	if(write(fd, buf, strlen(buf)) != strlen(buf))
48 		sysfatal("srv write: %r");
49 	close(fd);
50 	putenv(envname, name);
51 }
52 
/*
 * Return the index of string s in the null-terminated table l,
 * or -1 if s does not appear in it.
 */
static int
lookup(char *s, char *l[])
{
	char **p;

	for(p = l; *p != 0; p++){
		if(strcmp(*p, s) == 0)
			return p - l;
	}
	return -1;
}
63 
64 void
65 main(int argc, char **argv)
66 {
67 	char *mntpt;
68 	int fd, mntflags;
69 	int oldserver;
70 	char *srvpost, srvfile[64];
71 	int backwards = 0;
72 
73 	srvpost = nil;
74 	oldserver = 0;
75 	mntflags = MREPL;
76 	ARGBEGIN{
77 	case 'a':
78 		mntflags = MAFTER;
79 		break;
80 	case 'b':
81 		mntflags = MBEFORE;
82 		break;
83 	case 'c':
84 		mntflags |= MCREATE;
85 		break;
86 	case 'C':
87 		mntflags |= MCACHE;
88 		break;
89 	case 'd':
90 		debug++;
91 		break;
92 	case 'f':
93 		/* ignored but allowed for compatibility */
94 		break;
95 	case 'O':
96 	case 'o':
97 		oldserver = 1;
98 		break;
99 	case 'E':
100 		if ((encproto = lookup(EARGF(usage()), encprotos)) < 0)
101 			usage();
102 		break;
103 	case 'e':
104 		ealgs = EARGF(usage());
105 		if(*ealgs == 0 || strcmp(ealgs, "clear") == 0)
106 			ealgs = nil;
107 		break;
108 	case 'k':
109 		keyspec = EARGF(usage());
110 		break;
111 	case 'p':
112 		filterp = aan;
113 		break;
114 	case 's':
115 		srvpost = EARGF(usage());
116 		break;
117 	case 'B':
118 		backwards = 1;
119 		break;
120 	default:
121 		usage();
122 	}ARGEND;
123 
124 	mntpt = 0;		/* to shut up compiler */
125 	if(backwards){
126 		switch(argc) {
127 		default:
128 			mntpt = argv[0];
129 			break;
130 		case 0:
131 			usage();
132 		}
133 	} else {
134 		switch(argc) {
135 		case 2:
136 			mntpt = argv[1];
137 			break;
138 		case 3:
139 			mntpt = argv[2];
140 			break;
141 		default:
142 			usage();
143 		}
144 	}
145 
146 	if (encproto == Enctls)
147 		sysfatal("%s: tls has not yet been implemented\n", argv[0]);
148 
149 	notify(catcher);
150 	alarm(60*1000);
151 
152 	if(backwards)
153 		fd = passive();
154 	else
155 		fd = connect(argv[0], argv[1], oldserver);
156 
157 	if (!oldserver)
158 		fprint(fd, "impo %s %s\n", filterp? "aan": "nofilter", encprotos[encproto]);
159 
160 	if (encproto != Encnone && ealgs && ai) {
161 		uchar key[16];
162 		uchar digest[SHA1dlen];
163 		char fromclientsecret[21];
164 		char fromserversecret[21];
165 		int i;
166 
167 		memmove(key+4, ai->secret, ai->nsecret);
168 
169 		/* exchange random numbers */
170 		srand(truerand());
171 		for(i = 0; i < 4; i++)
172 			key[i] = rand();
173 		if(write(fd, key, 4) != 4)
174 			sysfatal("can't write key part: %r");
175 		if(readn(fd, key+12, 4) != 4)
176 			sysfatal("can't read key part: %r");
177 
178 		/* scramble into two secrets */
179 		sha1(key, sizeof(key), digest, nil);
180 		mksecret(fromclientsecret, digest);
181 		mksecret(fromserversecret, digest+10);
182 
183 		if (filterp)
184 			fd = filter(fd, filterp, argv[0]);
185 
186 		/* set up encryption */
187 		fd = pushssl(fd, ealgs, fromclientsecret, fromserversecret, nil);
188 		if(fd < 0)
189 			sysfatal("can't establish ssl connection: %r");
190 	}
191 	else if (filterp)
192 		fd = filter(fd, filterp, argv[0]);
193 
194 	if(srvpost){
195 		sprint(srvfile, "/srv/%s", srvpost);
196 		remove(srvfile);
197 		post(srvfile, srvpost, fd);
198 	}
199 	if(mount(fd, -1, mntpt, mntflags, "") < 0)
200 		sysfatal("can't mount %s: %r", argv[1]);
201 	alarm(0);
202 
203 	if(backwards && argc > 1){
204 		execl(argv[1], &argv[1]);
205 		sysfatal("exec: %r");
206 	}
207 	exits(0);
208 }
209 
210 void
211 catcher(void*, char *msg)
212 {
213 	if(strcmp(msg, "alarm") == 0)
214 		noted(NCONT);
215 	noted(NDFLT);
216 }
217 
218 int
219 old9p(int fd)
220 {
221 	int p[2];
222 
223 	if(pipe(p) < 0)
224 		sysfatal("pipe: %r");
225 
226 	switch(rfork(RFPROC|RFFDG|RFNAMEG)) {
227 	case -1:
228 		sysfatal("rfork srvold9p: %r");
229 	case 0:
230 		if(fd != 1){
231 			dup(fd, 1);
232 			close(fd);
233 		}
234 		if(p[0] != 0){
235 			dup(p[0], 0);
236 			close(p[0]);
237 		}
238 		close(p[1]);
239 		if(0){
240 			fd = open("/sys/log/cpu", OWRITE);
241 			if(fd != 2){
242 				dup(fd, 2);
243 				close(fd);
244 			}
245 			execl("/bin/srvold9p", "srvold9p", "-ds", 0);
246 		} else
247 			execl("/bin/srvold9p", "srvold9p", "-s", 0);
248 		sysfatal("exec srvold9p: %r");
249 	default:
250 		close(fd);
251 		close(p[0]);
252 	}
253 	return p[1];
254 }
255 
256 int
257 connect(char *system, char *tree, int oldserver)
258 {
259 	char buf[ERRMAX], dir[128], *na;
260 	int fd, n;
261 	char *authp;
262 
263 	na = netmkaddr(system, 0, "exportfs");
264 	if((fd = dial(na, 0, dir, 0)) < 0)
265 		sysfatal("can't dial %s: %r", system);
266 
267 	if(oldserver)
268 		authp = "p9sk2";
269 	else
270 		authp = "p9any";
271 
272 	ai = auth_proxy(fd, auth_getkey, "proto=%q role=client %s", authp, keyspec);
273 	if(ai == nil)
274 		sysfatal("%r: %s", system);
275 
276 	n = write(fd, tree, strlen(tree));
277 	if(n < 0)
278 		sysfatal("can't write tree: %r");
279 
280 	strcpy(buf, "can't read tree");
281 
282 	n = read(fd, buf, sizeof buf - 1);
283 	if(n!=2 || buf[0]!='O' || buf[1]!='K'){
284 		buf[sizeof buf - 1] = '\0';
285 		sysfatal("bad remote tree: %s", buf);
286 	}
287 
288 	if(oldserver)
289 		return old9p(fd);
290 	return fd;
291 }
292 
293 int
294 passive(void)
295 {
296 	int fd;
297 
298 	ai = auth_proxy(0, auth_getkey, "proto=p9any role=server");
299 	if(ai == nil)
300 		sysfatal("auth_proxy: %r");
301 	if(auth_chuid(ai, nil) < 0)
302 		sysfatal("auth_chuid: %r");
303 	putenv("service", "import");
304 
305 	fd = dup(0, -1);
306 	close(0);
307 	open("/dev/null", ORDWR);
308 	close(1);
309 	open("/dev/null", ORDWR);
310 
311 	return fd;
312 }
313 
/*
 * Print both invocation forms accepted by main (including -d, -o,
 * -s srvname and the -B form, all handled by main) and exit.
 */
void
usage(void)
{
	fprint(2, "usage: import [-abcCdo] [-E clear|ssl|tls] [-e 'crypt auth'|clear] [-k keypattern] [-p] [-s srvname] host remotefs [mountpoint]\n");
	fprint(2, "       import -B [options] mountpoint [command [arg ...]]\n");
	exits("usage");
}
320 
321 /* Network on fd1, mount driver on fd0 */
322 int
323 filter(int fd, char *cmd, char *host)
324 {
325 	int p[2], len, argc;
326 	char newport[256], buf[256], *s;
327 	char *argv[16], *file, *pbuf;
328 
329 	if ((len = read(fd, newport, sizeof newport - 1)) < 0)
330 		sysfatal("filter: cannot write port; %r\n");
331 	newport[len] = '\0';
332 
333 	if ((s = strchr(newport, '!')) == nil)
334 		sysfatal("filter: illegally formatted port %s\n", newport);
335 
336 	strecpy(buf, buf+sizeof buf, netmkaddr(host, "tcp", "0"));
337 	pbuf = strrchr(buf, '!');
338 	strecpy(pbuf, buf+sizeof buf, s);
339 
340 	if(debug)
341 		fprint(2, "filter: remote port %s\n", newport);
342 
343 	argc = tokenize(cmd, argv, nelem(argv)-2);
344 	if (argc == 0)
345 		sysfatal("filter: empty command");
346 	argv[argc++] = "-c";
347 	argv[argc++] = buf;
348 	argv[argc] = nil;
349 	file = argv[0];
350 	if (s = strrchr(argv[0], '/'))
351 		argv[0] = s+1;
352 
353 	if(pipe(p) < 0)
354 		sysfatal("pipe: %r");
355 
356 	switch(rfork(RFNOWAIT|RFPROC|RFFDG)) {
357 	case -1:
358 		sysfatal("rfork record module: %r");
359 	case 0:
360 		dup(p[0], 1);
361 		dup(p[0], 0);
362 		close(p[0]);
363 		close(p[1]);
364 		exec(file, argv);
365 		sysfatal("exec record module: %r");
366 	default:
367 		close(fd);
368 		close(p[0]);
369 	}
370 	return p[1];
371 }
372 
373 static void
374 mksecret(char *t, uchar *f)
375 {
376 	sprint(t, "%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux",
377 		f[0], f[1], f[2], f[3], f[4], f[5], f[6], f[7], f[8], f[9]);
378 }
379