xref: /plan9-contrib/sys/src/cmd/ip/tftpd.c (revision d46c239f8612929b7dbade67d0d071633df3a15d)
1 #include <u.h>
2 #include <libc.h>
3 #include <bio.h>
4 #include <ip.h>
5 #include <ndb.h>
6 
7 enum
8 {
9 	Maxpath=	128,
10 	Maxerr=		256,
11 };
12 
13 int 	dbg;
14 int	restricted;
15 void	sendfile(int, char*, char*);
16 void	recvfile(int, char*, char*);
17 void	nak(int, int, char*);
18 void	ack(int, ushort);
19 void	clrcon(void);
20 void	setuser(void);
21 char*	sunkernel(char*);
22 void	remoteaddr(char*, char*, int);
23 void	doserve(int);
24 
25 char	bigbuf[32768];
26 char	raddr[64];
27 
28 char	*dir = "/lib/tftpd";
29 char	*dirsl;
30 int	dirsllen;
31 char	flog[] = "ipboot";
32 char	net[Maxpath];
33 
34 enum
35 {
36 	Tftp_READ	= 1,
37 	Tftp_WRITE	= 2,
38 	Tftp_DATA	= 3,
39 	Tftp_ACK	= 4,
40 	Tftp_ERROR	= 5,
41 	Segsize		= 512,
42 };
43 
44 void
45 usage(void)
46 {
47 	fprint(2, "usage: %s [-dr] [-h homedir] [-x netmtpt]\n", argv0);
48 	exits("usage");
49 }
50 
51 void
52 main(int argc, char **argv)
53 {
54 	char buf[64];
55 	char adir[64], ldir[64];
56 	int cfd, lcfd, dfd;
57 	char *p;
58 
59 	setnetmtpt(net, sizeof(net), nil);
60 	ARGBEGIN{
61 	case 'd':
62 		dbg++;
63 		break;
64 	case 'h':
65 		dir = ARGF();
66 		break;
67 	case 'r':
68 		restricted = 1;
69 		break;
70 	case 'x':
71 		p = ARGF();
72 		if(p == nil)
73 			usage();
74 		setnetmtpt(net, sizeof(net), p);
75 		break;
76 	default:
77 		usage();
78 	}ARGEND
79 
80 	snprint(buf, sizeof buf, "%s/", dir);
81 	dirsl = strdup(buf);
82 	dirsllen = strlen(dirsl);
83 
84 	fmtinstall('E', eipfmt);
85 	fmtinstall('I', eipfmt);
86 
87 	if(chdir(dir) < 0)
88 		sysfatal("can't get to directory %s: %r", dir);
89 
90 	if(!dbg)
91 		switch(rfork(RFNOTEG|RFPROC|RFFDG)) {
92 		case -1:
93 			sysfatal("fork: %r");
94 		case 0:
95 			break;
96 		default:
97 			exits(0);
98 		}
99 
100 	syslog(dbg, flog, "started");
101 
102 	sprint(buf, "%s/udp!*!69", net);
103 	cfd = announce(buf, adir);
104 	setuser();
105 	for(;;) {
106 		lcfd = listen(adir, ldir);
107 		if(lcfd < 0)
108 			sysfatal("listening: %r");
109 
110 		switch(fork()) {
111 		case -1:
112 			sysfatal("fork: %r");
113 		case 0:
114 			dfd = accept(cfd, ldir);
115 			if(dfd < 0)
116  				exits(0);
117 			remoteaddr(ldir, raddr, sizeof(raddr));
118 			doserve(dfd);
119 			exits("done");
120 			break;
121 		default:
122 			close(lcfd);
123 			continue;
124 		}
125 	}
126 }
127 
128 void
129 doserve(int fd)
130 {
131 	int dlen;
132 	char *mode, *p;
133 	short op;
134 
135 	dlen = read(fd, bigbuf, sizeof(bigbuf));
136 	if(dlen < 0)
137 		sysfatal("listen read: %r");
138 
139 	op = (bigbuf[0]<<8) | bigbuf[1];
140 	dlen -= 2;
141 	mode = bigbuf+2;
142 	while(*mode != '\0' && dlen--)
143 		mode++;
144 	mode++;
145 	p = mode;
146 	while(*p && dlen--)
147 		p++;
148 	if(dlen == 0) {
149 		nak(fd, 0, "bad tftpmode");
150 		close(fd);
151 		syslog(dbg, flog, "bad mode %s", raddr);
152 		return;
153 	}
154 
155 	if(op != Tftp_READ && op != Tftp_WRITE) {
156 		nak(fd, 4, "Illegal TFTP operation");
157 		close(fd);
158 		syslog(dbg, flog, "bad request %d %s", op, raddr);
159 		return;
160 	}
161 
162 	if(restricted){
163 		if(strncmp(bigbuf+2, "../", 3) || strstr(bigbuf+2, "/../") ||
164 		  (bigbuf[2] == '/' && strncmp(bigbuf+2, dirsl, dirsllen)!=0)){
165 			nak(fd, 4, "Permission denied");
166 			close(fd);
167 			syslog(dbg, flog, "bad request %d %s", op, raddr);
168 			return;
169 		}
170 	}
171 
172 	if(op == Tftp_READ)
173 		sendfile(fd, bigbuf+2, mode);
174 	else
175 		recvfile(fd, bigbuf+2, mode);
176 }
177 
178 void
179 catcher(void *junk, char *msg)
180 {
181 	USED(junk);
182 
183 	if(strncmp(msg, "exit", 4) == 0)
184 		noted(NDFLT);
185 	noted(NCONT);
186 }
187 
188 void
189 sendfile(int fd, char *name, char *mode)
190 {
191 	int file;
192 	uchar buf[Segsize+4];
193 	uchar ack[1024];
194 	char errbuf[Maxerr];
195 	int ackblock, block, ret;
196 	int rexmit, n, al, txtry, rxl;
197 	short op;
198 
199 	syslog(dbg, flog, "send file '%s' %s to %s", name, mode, raddr);
200 	name = sunkernel(name);
201 	if(name == 0){
202 		nak(fd, 0, "not in our database");
203 		return;
204 	}
205 
206 	notify(catcher);
207 
208 	file = open(name, OREAD);
209 	if(file < 0) {
210 		errstr(errbuf, sizeof errbuf);
211 		nak(fd, 0, errbuf);
212 		return;
213 	}
214 	block = 0;
215 	rexmit = 0;
216 	n = 0;
217 	for(txtry = 0; txtry < 5;) {
218 		if(rexmit == 0) {
219 			block++;
220 			buf[0] = 0;
221 			buf[1] = Tftp_DATA;
222 			buf[2] = block>>8;
223 			buf[3] = block;
224 			n = read(file, buf+4, Segsize);
225 			if(n < 0) {
226 				errstr(errbuf, sizeof errbuf);
227 				nak(fd, 0, errbuf);
228 				return;
229 			}
230 			txtry = 0;
231 		}
232 		else {
233 			syslog(dbg, flog, "rexmit %d %s:%d to %s", 4+n, name, block, raddr);
234 			txtry++;
235 		}
236 
237 		ret = write(fd, buf, 4+n);
238 		if(ret < 0)
239 			sysfatal("tftpd: network write error: %r");
240 
241 		for(rxl = 0; rxl < 10; rxl++) {
242 			rexmit = 0;
243 			alarm(500);
244 			al = read(fd, ack, sizeof(ack));
245 			alarm(0);
246 			if(al < 0) {
247 				rexmit = 1;
248 				break;
249 			}
250 			op = ack[0]<<8|ack[1];
251 			if(op == Tftp_ERROR)
252 				goto error;
253 			ackblock = ack[2]<<8|ack[3];
254 			if(ackblock == block)
255 				break;
256 			if(ackblock == 0xffff) {
257 				rexmit = 1;
258 				break;
259 			}
260 		}
261 		if(ret != Segsize+4 && rexmit == 0)
262 			break;
263 	}
264 error:
265 	close(fd);
266 	close(file);
267 }
268 
269 void
270 recvfile(int fd, char *name, char *mode)
271 {
272 	ushort op, block, inblock;
273 	uchar buf[Segsize+8];
274 	char errbuf[Maxerr];
275 	int n, ret, file;
276 
277 	syslog(dbg, flog, "receive file '%s' %s from %s", name, mode, raddr);
278 
279 	file = create(name, OWRITE, 0666);
280 	if(file < 0) {
281 		errstr(errbuf, sizeof errbuf);
282 		nak(fd, 0, errbuf);
283 		return;
284 	}
285 
286 	block = 0;
287 	ack(fd, block);
288 	block++;
289 
290 	for(;;) {
291 		alarm(15000);
292 		n = read(fd, buf, sizeof(buf));
293 		alarm(0);
294 		if(n < 0)
295 			goto error;
296 		op = buf[0]<<8|buf[1];
297 		if(op == Tftp_ERROR)
298 			goto error;
299 
300 		n -= 4;
301 		inblock = buf[2]<<8|buf[3];
302 		if(op == Tftp_DATA) {
303 			if(inblock == block) {
304 				ret = write(file, buf, n);
305 				if(ret < 0) {
306 					errstr(errbuf, sizeof errbuf);
307 					nak(fd, 0, errbuf);
308 					goto error;
309 				}
310 				ack(fd, block);
311 				block++;
312 			}
313 			ack(fd, 0xffff);
314 		}
315 	}
316 error:
317 	close(file);
318 }
319 
320 void
321 ack(int fd, ushort block)
322 {
323 	uchar ack[4];
324 	int n;
325 
326 	ack[0] = 0;
327 	ack[1] = Tftp_ACK;
328 	ack[2] = block>>8;
329 	ack[3] = block;
330 
331 	n = write(fd, ack, 4);
332 	if(n < 0)
333 		sysfatal("network write: %r");
334 }
335 
336 void
337 nak(int fd, int code, char *msg)
338 {
339 	char buf[128];
340 	int n;
341 
342 	buf[0] = 0;
343 	buf[1] = Tftp_ERROR;
344 	buf[2] = 0;
345 	buf[3] = code;
346 	strcpy(buf+4, msg);
347 	n = strlen(msg) + 4 + 1;
348 	n = write(fd, buf, n);
349 	if(n < 0)
350 		sysfatal("write nak: %r");
351 }
352 
353 void
354 setuser(void)
355 {
356 	int f;
357 
358 	f = open("/dev/user", OWRITE);
359 	if(f < 0)
360 		return;
361 	write(f, "none", sizeof("none"));
362 	close(f);
363 }
364 
365 char*
366 lookup(char *sattr, char *sval, char *tattr, char *tval)
367 {
368 	static Ndb *db;
369 	char *attrs[1];
370 	Ndbtuple *t;
371 
372 	if(db == nil)
373 		db = ndbopen(0);
374 	if(db == nil)
375 		return nil;
376 
377 	if(sattr == nil)
378 		sattr = ipattr(sval);
379 
380 	attrs[0] = tattr;
381 	t = ndbipinfo(db, sattr, sval, attrs, 1);
382 	if(t == nil)
383 		return nil;
384 	strcpy(tval, t->val);
385 	ndbfree(t);
386 	return tval;
387 }
388 
389 /*
390  *  for sun kernel boots, replace the requested file name with
391  *  a one from our database.  If the database doesn't specify a file,
392  *  don't answer.
393  */
394 char*
395 sunkernel(char *name)
396 {
397 	ulong addr;
398 	uchar v4[IPv4addrlen];
399 	uchar v6[IPaddrlen];
400 	char buf[Ndbvlen];
401 	char ipbuf[Ndbvlen];
402 
403 	if(strlen(name) != 14 || strncmp(name + 8, ".SUN", 4) != 0)
404 		return name;
405 
406 	addr = strtoul(name, 0, 16);
407 	v4[0] = addr>>24;
408 	v4[1] = addr>>16;
409 	v4[2] = addr>>8;
410 	v4[3] = addr;
411 	v4tov6(v6, v4);
412 	sprint(ipbuf, "%I", v6);
413 	return lookup("ip", ipbuf, "bootf", buf);
414 }
415 
416 void
417 remoteaddr(char *dir, char *raddr, int len)
418 {
419 	char buf[64];
420 	int fd, n;
421 
422 	snprint(buf, sizeof(buf), "%s/remote", dir);
423 	fd = open(buf, OREAD);
424 	if(fd < 0){
425 		snprint(raddr, sizeof(raddr), "unknown");
426 		return;
427 	}
428 	n = read(fd, raddr, len-1);
429 	close(fd);
430 	if(n <= 0){
431 		snprint(raddr, sizeof(raddr), "unknown");
432 		return;
433 	}
434 	if(n > 0)
435 		n--;
436 	raddr[n] = 0;
437 }
438