xref: /inferno-os/os/boot/rpcg/bootp.c (revision 4eb166cf184c1f102fb79e31b1465ea3e2021c39)
1 #include "u.h"
2 #include "lib.h"
3 #include "mem.h"
4 #include "dat.h"
5 #include "fns.h"
6 #include "io.h"
7 
8 #include "ip.h"
9 
/*
 * Clear the KSEGM bits of an address.  bootp uses it to turn the
 * a.out entry point into the address the image is copied to and
 * jumped at.
 */
#define	XPADDR(a)	((ulong)(a) & ~KSEGM)

enum {
	/*
	 * When non-zero, udprecv verifies the UDP checksum of incoming
	 * packets that carry one.  Set zero if trouble booting from Linux.
	 */
	CHECKSUM = 1,	/* set zero if trouble booting from Linux */
};

/* Ethernet broadcast address; destination of the BOOTP request. */
uchar broadcast[Eaddrlen] = {
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
};

/* local UDP port for TFTP; bootp takes a fresh one per call */
static ushort tftpport = 5000;
/* next IP identification value, stamped by udpsend */
static int Id = 1;
/* our IP address, UDP port and Ethernet address */
static Netaddr myaddr;
/* the BOOTP/TFTP server we talk to */
static Netaddr server;

/*
 * One TFTP packet as received: header[0..1] is the opcode,
 * header[2..3] the block number (DATA) or error code (ERROR),
 * followed by up to Segsize bytes of data.
 */
typedef struct {
	uchar	header[4];
	uchar	data[Segsize];
} Tftp;
/* receive buffer used while loading the boot image */
static Tftp tftpb;
30 
/*
 * Store the 16-bit value val at ptr[0..1] in network (big-endian)
 * byte order.
 */
static void
hnputs(uchar *ptr, ushort val)
{
	uchar hi, lo;

	hi = (val >> 8) & 0xFF;
	lo = val & 0xFF;
	ptr[0] = hi;
	ptr[1] = lo;
}
37 
/*
 * Store the low 32 bits of val at ptr[0..3] in network (big-endian)
 * byte order: most significant byte first.
 */
static void
hnputl(uchar *ptr, ulong val)
{
	int i;

	for(i = 3; i >= 0; i--){
		ptr[i] = val & 0xFF;
		val >>= 8;
	}
}
46 
/*
 * Read a 32-bit big-endian (network order) value from ptr[0..3].
 *
 * Each byte is widened to ulong before shifting.  ptr[0] would otherwise
 * be promoted to int, and for ptr[0] >= 0x80 "<<24" would shift into
 * the sign bit (undefined behaviour, and a sign-extended result where
 * ulong is wider than int).
 */
static ulong
nhgetl(uchar *ptr)
{
	return ((ulong)ptr[0]<<24) | ((ulong)ptr[1]<<16) | ((ulong)ptr[2]<<8) | (ulong)ptr[3];
}
52 
/*
 * Read a 16-bit big-endian (network order) value from ptr[0..1].
 */
static ushort
nhgets(uchar *ptr)
{
	ushort v;

	v = ptr[0];
	v = (v << 8) | ptr[1];
	return v;
}
58 
/*
 * Run-time byte-order probe.  endian holds 1; LITTLE reads its first
 * byte, which is non-zero only on a little-endian host.  ptcl_csum
 * uses it to place natively loaded 16-bit words correctly.
 */
static	short	endian	= 1;
static	char*	aendian	= (char*)&endian;
#define	LITTLE	*aendian
62 
/*
 * One's-complement (Internet) checksum of len bytes at a, used for the
 * UDP checksum over pseudo-header + UDP header + data.  Returns the
 * complement of the folded 16-bit sum.  udprecv treats a non-zero
 * result over a received datagram as a checksum error.
 *
 * 16-bit words are loaded natively (via ushort), so their byte order
 * depends on the host (LITTLE) and on whether the buffer started at an
 * odd address.  Contributions whose bytes are in network order go into
 * losum.  Contributions that are byte-swapped relative to network order
 * go into hisum, whose bytes are swapped back into losum at the end.
 */
static ushort
ptcl_csum(void *a, int len)
{
	uchar *addr;
	ulong t1, t2;
	ulong losum, hisum, mdsum, x;

	addr = a;
	losum = 0;
	hisum = 0;
	mdsum = 0;

	/* consume a leading odd-address byte so the ushort loads below are aligned */
	x = 0;
	if((ulong)addr & 1) {
		if(len) {
			hisum += addr[0];
			len--;
			addr++;
		}
		x = 1;
	}
	/* main loop, unrolled: 8 native 16-bit words per iteration */
	while(len >= 16) {
		t1 = *(ushort*)(addr+0);
		t2 = *(ushort*)(addr+2);	mdsum += t1;
		t1 = *(ushort*)(addr+4);	mdsum += t2;
		t2 = *(ushort*)(addr+6);	mdsum += t1;
		t1 = *(ushort*)(addr+8);	mdsum += t2;
		t2 = *(ushort*)(addr+10);	mdsum += t1;
		t1 = *(ushort*)(addr+12);	mdsum += t2;
		t2 = *(ushort*)(addr+14);	mdsum += t1;
		mdsum += t2;
		len -= 16;
		addr += 16;
	}
	while(len >= 2) {
		mdsum += *(ushort*)addr;
		len -= 2;
		addr += 2;
	}
	/*
	 * Place the trailing odd byte (if any) and the native word sum in
	 * the half that matches their network-order position.  An odd
	 * starting address shifts the word pairing by one byte, which
	 * flips the choice.
	 */
	if(x) {
		if(len)
			losum += addr[0];
		if(LITTLE)
			losum += mdsum;
		else
			hisum += mdsum;
	} else {
		if(len)
			hisum += addr[0];
		if(LITTLE)
			hisum += mdsum;
		else
			losum += mdsum;
	}

	/* add hisum with its bytes swapped, then fold carries into 16 bits */
	losum += hisum >> 8;
	losum += (hisum & 0xff) << 8;
	while(hisum = losum>>16)
		losum = hisum + (losum & 0xffff);

	return ~losum;
}
125 
/*
 * IPv4 header checksum.  addr points at the version/IHL byte.  The
 * header length (IHL * 4 bytes) is read from its low nibble.  Returns
 * the one's complement of the one's-complement sum of the header's
 * big-endian 16-bit words: 0 for a header whose checksum field is
 * correct, or the value to store when the field is zero.
 */
static ushort
ip_csum(uchar *addr)
{
	int i, hlen;
	ulong acc;

	hlen = (addr[0] & 0x0F) * 4;
	acc = 0;
	for(i = 0; i < hlen; i += 2)
		acc += (addr[i] << 8) | addr[i+1];

	/* two folds absorb every carry out of the low 16 bits */
	acc = (acc >> 16) + (acc & 0xFFFF);
	acc = (acc >> 16) + (acc & 0xFFFF);
	return acc ^ 0xFFFF;
}
144 
145 static void
146 udpsend(int ctlrno, Netaddr *a, void *data, int dlen)
147 {
148 	Udphdr *uh;
149 	Etherhdr *ip;
150 	static Etherpkt pkt;
151 	int len, ptcllen;
152 
153 
154 	uh = (Udphdr*)&pkt;
155 
156 	memset(uh, 0, sizeof(Etherpkt));
157 	memmove(uh->udpcksum+sizeof(uh->udpcksum), data, dlen);
158 
159 	/*
160 	 * UDP portion
161 	 */
162 	ptcllen = dlen + (UDP_HDRSIZE-UDP_PHDRSIZE);
163 	uh->ttl = 0;
164 	uh->udpproto = IP_UDPPROTO;
165 	uh->frag[0] = 0;
166 	uh->frag[1] = 0;
167 	hnputs(uh->udpplen, ptcllen);
168 	hnputl(uh->udpsrc, myaddr.ip);
169 	hnputs(uh->udpsport, myaddr.port);
170 	hnputl(uh->udpdst, a->ip);
171 	hnputs(uh->udpdport, a->port);
172 	hnputs(uh->udplen, ptcllen);
173 	uh->udpcksum[0] = 0;
174 	uh->udpcksum[1] = 0;
175 	/*dlen = (dlen+1)&~1; */
176 	hnputs(uh->udpcksum, ptcl_csum(&uh->ttl, dlen+UDP_HDRSIZE));
177 
178 	/*
179 	 * IP portion
180 	 */
181 	ip = (Etherhdr*)&pkt;
182 	len = sizeof(Udphdr)+dlen;
183 	ip->vihl = IP_VER|IP_HLEN;
184 	ip->tos = 0;
185 	ip->ttl = 255;
186 	hnputs(ip->length, len-ETHER_HDR);
187 	hnputs(ip->id, Id++);
188 	ip->frag[0] = 0;
189 	ip->frag[1] = 0;
190 	ip->cksum[0] = 0;
191 	ip->cksum[1] = 0;
192 	hnputs(ip->cksum, ip_csum(&ip->vihl));
193 
194 	/*
195 	 * Ethernet MAC portion
196 	 */
197 	hnputs(ip->type, ET_IP);
198 	memmove(ip->d, a->ea, sizeof(ip->d));
199 
200 	ethertxpkt(ctlrno, &pkt, len, Timeout);
201 }
202 
203 static void
204 nak(int ctlrno, Netaddr *a, int code, char *msg, int report)
205 {
206 	int n;
207 	char buf[128];
208 
209 	buf[0] = 0;
210 	buf[1] = Tftp_ERROR;
211 	buf[2] = 0;
212 	buf[3] = code;
213 	strcpy(buf+4, msg);
214 	n = strlen(msg) + 4 + 1;
215 	udpsend(ctlrno, a, buf, n);
216 	if(report)
217 		print("\ntftp: error(%d): %s\n", code, msg);
218 }
219 
220 static int
221 udprecv(int ctlrno, Netaddr *a, void *data, int dlen)
222 {
223 	int n, len;
224 	ushort csm;
225 	Udphdr *h;
226 	ulong addr, timo;
227 	static Etherpkt pkt;
228 	static int rxactive;
229 
230 	if(rxactive == 0)
231 		timo = 1000;
232 	else
233 		timo = Timeout;
234 	timo += TK2MS(m->ticks);
235 	while(timo > TK2MS(m->ticks)){
236 		n = etherrxpkt(ctlrno, &pkt, timo-TK2MS(m->ticks));
237 		if(n <= 0)
238 			continue;
239 
240 		h = (Udphdr*)&pkt;
241 		if(nhgets(h->type) != ET_IP)
242 			continue;
243 
244 		if(ip_csum(&h->vihl)) {
245 			print("ip chksum error\n");
246 			continue;
247 		}
248 		if(h->vihl != (IP_VER|IP_HLEN)) {
249 			print("ip bad vers/hlen\n");
250 			continue;
251 		}
252 
253 		if(h->udpproto != IP_UDPPROTO)
254 			continue;
255 
256 		h->ttl = 0;
257 		len = nhgets(h->udplen);
258 		hnputs(h->udpplen, len);
259 
260 		if(CHECKSUM && nhgets(h->udpcksum)) {
261 			csm = ptcl_csum(&h->ttl, len+UDP_PHDRSIZE);
262 			if(csm != 0) {
263 				print("udp chksum error csum #%4lux len %d\n", csm, n);
264 				break;
265 			}
266 		}
267 
268 		if(a->port != 0 && nhgets(h->udpsport) != a->port)
269 			continue;
270 
271 		addr = nhgetl(h->udpsrc);
272 		if(a->ip != Bcastip && addr != a->ip)
273 			continue;
274 
275 		len -= UDP_HDRSIZE-UDP_PHDRSIZE;
276 		if(len > dlen) {
277 			print("udp: packet too big\n");
278 			continue;
279 		}
280 
281 		memmove(data, h->udpcksum+sizeof(h->udpcksum), len);
282 		a->ip = addr;
283 		a->port = nhgets(h->udpsport);
284 		memmove(a->ea, pkt.s, sizeof(a->ea));
285 
286 		rxactive = 1;
287 		return len;
288 	}
289 
290 	return 0;
291 }
292 
/*
 * Number of the most recently accepted TFTP DATA block: set to 1 by
 * tftpopen, acknowledged and advanced by tftpread.
 */
static int tftpblockno;
294 
295 static int
296 tftpopen(int ctlrno, Netaddr *a, char *name, Tftp *tftp)
297 {
298 	int i, len, rlen, oport;
299 	char buf[Segsize+2];
300 
301 	buf[0] = 0;
302 	buf[1] = Tftp_READ;
303 	len = sprint(buf+2, "%s", name) + 2;
304 	len += sprint(buf+len+1, "octet") + 2;
305 
306 	oport = a->port;
307 	for(i = 0; i < 5; i++){
308 		a->port = oport;
309 		udpsend(ctlrno, a, buf, len);
310 		a->port = 0;
311 		if((rlen = udprecv(ctlrno, a, tftp, sizeof(Tftp))) < sizeof(tftp->header))
312 			continue;
313 
314 		switch((tftp->header[0]<<8)|tftp->header[1]){
315 
316 		case Tftp_ERROR:
317 			print("tftpopen: error (%d): %s\n",
318 				(tftp->header[2]<<8)|tftp->header[3], tftp->data);
319 			return -1;
320 
321 		case Tftp_DATA:
322 			tftpblockno = 1;
323 			len = (tftp->header[2]<<8)|tftp->header[3];
324 			if(len != tftpblockno){
325 				print("tftpopen: block error: %d\n", len);
326 				nak(ctlrno, a, 1, "block error", 0);
327 				return -1;
328 			}
329 			return rlen-sizeof(tftp->header);
330 		}
331 	}
332 
333 	print("tftpopen: failed to connect to server\n");
334 	return -1;
335 }
336 
337 static int
338 tftpread(int ctlrno, Netaddr *a, Tftp *tftp, int dlen)
339 {
340 	int blockno, len, retry;
341 	uchar buf[4];
342 
343 	buf[0] = 0;
344 	buf[1] = Tftp_ACK;
345 	buf[2] = tftpblockno>>8;
346 	buf[3] = tftpblockno;
347 	tftpblockno++;
348 
349 	dlen += sizeof(tftp->header);
350 
351 	retry = 0;
352 buggery:
353 	udpsend(ctlrno, a, buf, sizeof(buf));
354 
355 	if((len = udprecv(ctlrno, a, tftp, dlen)) < dlen){
356 		print("tftpread: %d != %d\n", len, dlen);
357 		nak(ctlrno, a, 2, "short read", 0);
358 		if(retry++ < 5)
359 			goto buggery;
360 		return -1;
361 	}
362 
363 	blockno = (tftp->header[2]<<8)|tftp->header[3];
364 	if(blockno != tftpblockno){
365 		print("?");
366 
367 		if(blockno == tftpblockno-1 && retry++ < 8)
368 			goto buggery;
369 		print("tftpread: block error: %d, expected %d\n", blockno, tftpblockno);
370 		nak(ctlrno, a, 1, "block error", 0);
371 
372 		return -1;
373 	}
374 
375 	return len-sizeof(tftp->header);
376 }
377 
378 int
379 bootp(int ctlrno, char *file)
380 {
381 	Bootp req, rep;
382 	int i, dlen, segsize, text, data, bss, total;
383 	uchar *ea, *addr, *p;
384 	ulong entry;
385 	Exec *exec;
386 	char name[128], *filename, *sysname;
387 
388 	if((ea = etheraddr(ctlrno)) == 0){
389 		print("invalid ctlrno %d\n", ctlrno);
390 		return -1;
391 	}
392 
393 	filename = 0;
394 	sysname = 0;
395 	if(file && *file){
396 		strcpy(name, file);
397 		if(filename = strchr(name, ':')){
398 			if(filename != name && *(filename-1) != '\\'){
399 				sysname = name;
400 				*filename++ = 0;
401 			}
402 		}
403 		else
404 			filename = name;
405 	}
406 
407 
408 	memset(&req, 0, sizeof(req));
409 	req.op = Bootrequest;
410 	req.htype = 1;			/* ethernet */
411 	req.hlen = Eaddrlen;		/* ethernet */
412 	memmove(req.chaddr, ea, Eaddrlen);
413 
414 	myaddr.ip = 0;
415 	myaddr.port = BPportsrc;
416 	memmove(myaddr.ea, ea, Eaddrlen);
417 
418 	for(i = 0; i < 10; i++) {
419 		server.ip = Bcastip;
420 		server.port = BPportdst;
421 		memmove(server.ea, broadcast, sizeof(server.ea));
422 		udpsend(ctlrno, &server, &req, sizeof(req));
423 		if(udprecv(ctlrno, &server, &rep, sizeof(rep)) <= 0)
424 			continue;
425 		if(memcmp(req.chaddr, rep.chaddr, Eaddrlen))
426 			continue;
427 		if(rep.htype != 1 || rep.hlen != Eaddrlen)
428 			continue;
429 		if(sysname == 0 || strcmp(sysname, rep.sname) == 0)
430 			break;
431 	}
432 	if(i >= 10) {
433 		print("bootp timed out\n");
434 		return -1;
435 	}
436 
437 	if(filename == 0 || *filename == 0)
438 		filename = rep.file;
439 
440 	if(rep.sname[0] != '\0')
441 		 print("%s ", rep.sname);
442 	print("(%d.%d.%d.%d!%d): %s\n",
443 		rep.siaddr[0],
444 		rep.siaddr[1],
445 		rep.siaddr[2],
446 		rep.siaddr[3],
447 		server.port,
448 		filename);uartwait();
449 
450 	myaddr.ip = nhgetl(rep.yiaddr);
451 	myaddr.port = tftpport++;
452 	server.ip = nhgetl(rep.siaddr);
453 	server.port = TFTPport;
454 
455 	if((dlen = tftpopen(ctlrno, &server, filename, &tftpb)) < 0)
456 		return -1;
457 	exec = (Exec*)(tftpb.data);
458 	if(dlen < sizeof(Exec) || GLLONG(exec->magic) != Q_MAGIC){
459 		nak(ctlrno, &server, 0, "bad magic number", 1);
460 		return -1;
461 	}
462 	text = GLLONG(exec->text);
463 	data = GLLONG(exec->data);
464 	bss = GLLONG(exec->bss);
465 	total = text+data+bss;
466 	entry = GLLONG(exec->entry);
467 print("load@%8.8lux: ", XPADDR(entry));uartwait();
468 	print("%d", text);
469 
470 	addr = (uchar*)XPADDR(entry);
471 	p = tftpb.data+sizeof(Exec);
472 	dlen -= sizeof(Exec);
473 	segsize = text;
474 	for(;;){
475 		if(dlen == 0){
476 			if((dlen = tftpread(ctlrno, &server, &tftpb, sizeof(tftpb.data))) < 0)
477 				return -1;
478 			p = tftpb.data;
479 		}
480 		if(segsize <= dlen)
481 			i = segsize;
482 		else
483 			i = dlen;
484 		memmove(addr, p, i);
485 
486 		addr += i;
487 		p += i;
488 		segsize -= i;
489 		dlen -= i;
490 
491 		if(segsize <= 0){
492 			if(data == 0)
493 				break;
494 			print("+%d", data);
495 			segsize = data;
496 			data = 0;
497 			addr = (uchar*)PGROUND((ulong)addr);
498 		}
499 	}
500 	nak(ctlrno, &server, 3, "ok", 0);		/* tftpclose */
501 	print("+%d=%d\n", bss, total);
502 	print("entry: 0x%lux\n", entry);
503 	uartwait();
504 	scc2stop();
505 	splhi();
506 	(*(void(*)(void))(XPADDR(entry)))();
507 
508 	return 0;
509 }
510