xref: /inferno-os/os/boot/mpc/bootp.c (revision 9dc22068e29604f4b484e746112a9a4efe6fd57f)
1 #include "u.h"
2 #include "lib.h"
3 #include "mem.h"
4 #include "dat.h"
5 #include "fns.h"
6 #include "io.h"
7 
8 #include "ip.h"
9 
10 enum {
11 	CHECKSUM = 1,	/* set zero if trouble booting from Linux */
12 };
13 
14 uchar broadcast[Eaddrlen] = {
15 	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
16 };
17 
18 static ushort tftpport = 5000;
19 static int Id = 1;
20 static Netaddr myaddr;
21 static Netaddr server;
22 
23 typedef struct {
24 	uchar	header[4];
25 	uchar	data[Segsize];
26 } Tftp;
27 static Tftp tftpb;
28 
29 static void
30 hnputs(uchar *ptr, ushort val)
31 {
32 	ptr[0] = val>>8;
33 	ptr[1] = val;
34 }
35 
36 static void
37 hnputl(uchar *ptr, ulong val)
38 {
39 	ptr[0] = val>>24;
40 	ptr[1] = val>>16;
41 	ptr[2] = val>>8;
42 	ptr[3] = val;
43 }
44 
45 static ulong
46 nhgetl(uchar *ptr)
47 {
48 	return ((ptr[0]<<24) | (ptr[1]<<16) | (ptr[2]<<8) | ptr[3]);
49 }
50 
51 static ushort
52 nhgets(uchar *ptr)
53 {
54 	return ((ptr[0]<<8) | ptr[1]);
55 }
56 
57 static	short	endian	= 1;
58 static	char*	aendian	= (char*)&endian;
59 #define	LITTLE	*aendian
60 
61 static ushort
62 ptcl_csum(void *a, int len)
63 {
64 	uchar *addr;
65 	ulong t1, t2;
66 	ulong losum, hisum, mdsum, x;
67 
68 	addr = a;
69 	losum = 0;
70 	hisum = 0;
71 	mdsum = 0;
72 
73 	x = 0;
74 	if((ulong)addr & 1) {
75 		if(len) {
76 			hisum += addr[0];
77 			len--;
78 			addr++;
79 		}
80 		x = 1;
81 	}
82 	while(len >= 16) {
83 		t1 = *(ushort*)(addr+0);
84 		t2 = *(ushort*)(addr+2);	mdsum += t1;
85 		t1 = *(ushort*)(addr+4);	mdsum += t2;
86 		t2 = *(ushort*)(addr+6);	mdsum += t1;
87 		t1 = *(ushort*)(addr+8);	mdsum += t2;
88 		t2 = *(ushort*)(addr+10);	mdsum += t1;
89 		t1 = *(ushort*)(addr+12);	mdsum += t2;
90 		t2 = *(ushort*)(addr+14);	mdsum += t1;
91 		mdsum += t2;
92 		len -= 16;
93 		addr += 16;
94 	}
95 	while(len >= 2) {
96 		mdsum += *(ushort*)addr;
97 		len -= 2;
98 		addr += 2;
99 	}
100 	if(x) {
101 		if(len)
102 			losum += addr[0];
103 		if(LITTLE)
104 			losum += mdsum;
105 		else
106 			hisum += mdsum;
107 	} else {
108 		if(len)
109 			hisum += addr[0];
110 		if(LITTLE)
111 			hisum += mdsum;
112 		else
113 			losum += mdsum;
114 	}
115 
116 	losum += hisum >> 8;
117 	losum += (hisum & 0xff) << 8;
118 	while(hisum = losum>>16)
119 		losum = hisum + (losum & 0xffff);
120 
121 	return ~losum;
122 }
123 
124 static ushort
125 ip_csum(uchar *addr)
126 {
127 	int len;
128 	ulong sum = 0;
129 
130 	len = (addr[0]&0xf)<<2;
131 
132 	while(len > 0) {
133 		sum += addr[0]<<8 | addr[1] ;
134 		len -= 2;
135 		addr += 2;
136 	}
137 
138 	sum = (sum & 0xffff) + (sum >> 16);
139 	sum = (sum & 0xffff) + (sum >> 16);
140 	return (sum^0xffff);
141 }
142 
143 static void
144 udpsend(int ctlrno, Netaddr *a, void *data, int dlen)
145 {
146 	Udphdr *uh;
147 	Etherhdr *ip;
148 	static Etherpkt pkt;
149 	int len, ptcllen;
150 
151 
152 	uh = (Udphdr*)&pkt;
153 
154 	memset(uh, 0, sizeof(Etherpkt));
155 	memmove(uh->udpcksum+sizeof(uh->udpcksum), data, dlen);
156 
157 	/*
158 	 * UDP portion
159 	 */
160 	ptcllen = dlen + (UDP_HDRSIZE-UDP_PHDRSIZE);
161 	uh->ttl = 0;
162 	uh->udpproto = IP_UDPPROTO;
163 	uh->frag[0] = 0;
164 	uh->frag[1] = 0;
165 	hnputs(uh->udpplen, ptcllen);
166 	hnputl(uh->udpsrc, myaddr.ip);
167 	hnputs(uh->udpsport, myaddr.port);
168 	hnputl(uh->udpdst, a->ip);
169 	hnputs(uh->udpdport, a->port);
170 	hnputs(uh->udplen, ptcllen);
171 	uh->udpcksum[0] = 0;
172 	uh->udpcksum[1] = 0;
173 	/*dlen = (dlen+1)&~1; */
174 	hnputs(uh->udpcksum, ptcl_csum(&uh->ttl, dlen+UDP_HDRSIZE));
175 
176 	/*
177 	 * IP portion
178 	 */
179 	ip = (Etherhdr*)&pkt;
180 	len = sizeof(Udphdr)+dlen;
181 	ip->vihl = IP_VER|IP_HLEN;
182 	ip->tos = 0;
183 	ip->ttl = 255;
184 	hnputs(ip->length, len-ETHER_HDR);
185 	hnputs(ip->id, Id++);
186 	ip->frag[0] = 0;
187 	ip->frag[1] = 0;
188 	ip->cksum[0] = 0;
189 	ip->cksum[1] = 0;
190 	hnputs(ip->cksum, ip_csum(&ip->vihl));
191 
192 	/*
193 	 * Ethernet MAC portion
194 	 */
195 	hnputs(ip->type, ET_IP);
196 	memmove(ip->d, a->ea, sizeof(ip->d));
197 
198 	ethertxpkt(ctlrno, &pkt, len, Timeout);
199 }
200 
201 static void
202 nak(int ctlrno, Netaddr *a, int code, char *msg, int report)
203 {
204 	int n;
205 	char buf[128];
206 
207 	buf[0] = 0;
208 	buf[1] = Tftp_ERROR;
209 	buf[2] = 0;
210 	buf[3] = code;
211 	strcpy(buf+4, msg);
212 	n = strlen(msg) + 4 + 1;
213 	udpsend(ctlrno, a, buf, n);
214 	if(report)
215 		print("\ntftp: error(%d): %s\n", code, msg);
216 }
217 
218 static int
219 udprecv(int ctlrno, Netaddr *a, void *data, int dlen)
220 {
221 	int n, len;
222 	ushort csm;
223 	Udphdr *h;
224 	ulong addr, timo;
225 	static Etherpkt pkt;
226 	static int rxactive;
227 
228 	if(rxactive == 0)
229 		timo = 1000;
230 	else
231 		timo = Timeout;
232 	timo += TK2MS(m->ticks);
233 	while(timo > TK2MS(m->ticks)){
234 		n = etherrxpkt(ctlrno, &pkt, timo-TK2MS(m->ticks));
235 		if(n <= 0)
236 			continue;
237 
238 		h = (Udphdr*)&pkt;
239 		if(nhgets(h->type) != ET_IP)
240 			continue;
241 
242 		if(ip_csum(&h->vihl)) {
243 			print("ip chksum error\n");
244 			continue;
245 		}
246 		if(h->vihl != (IP_VER|IP_HLEN)) {
247 			print("ip bad vers/hlen\n");
248 			continue;
249 		}
250 
251 		if(h->udpproto != IP_UDPPROTO)
252 			continue;
253 
254 		h->ttl = 0;
255 		len = nhgets(h->udplen);
256 		hnputs(h->udpplen, len);
257 
258 		if(CHECKSUM && nhgets(h->udpcksum)) {
259 			csm = ptcl_csum(&h->ttl, len+UDP_PHDRSIZE);
260 			if(csm != 0) {
261 				print("udp chksum error csum #%4lux len %d\n", csm, n);
262 				break;
263 			}
264 		}
265 
266 		if(a->port != 0 && nhgets(h->udpsport) != a->port)
267 			continue;
268 
269 		addr = nhgetl(h->udpsrc);
270 		if(a->ip != Bcastip && addr != a->ip)
271 			continue;
272 
273 		len -= UDP_HDRSIZE-UDP_PHDRSIZE;
274 		if(len > dlen) {
275 			print("udp: packet too big\n");
276 			continue;
277 		}
278 
279 		memmove(data, h->udpcksum+sizeof(h->udpcksum), len);
280 		a->ip = addr;
281 		a->port = nhgets(h->udpsport);
282 		memmove(a->ea, pkt.s, sizeof(a->ea));
283 
284 		rxactive = 1;
285 		return len;
286 	}
287 
288 	return 0;
289 }
290 
291 static int tftpblockno;
292 
293 static int
294 tftpopen(int ctlrno, Netaddr *a, char *name, Tftp *tftp)
295 {
296 	int i, len, rlen, oport;
297 	char buf[Segsize+2];
298 
299 	buf[0] = 0;
300 	buf[1] = Tftp_READ;
301 	len = sprint(buf+2, "%s", name) + 2;
302 	len += sprint(buf+len+1, "octet") + 2;
303 
304 	oport = a->port;
305 	for(i = 0; i < 5; i++){
306 		a->port = oport;
307 		udpsend(ctlrno, a, buf, len);
308 		a->port = 0;
309 		if((rlen = udprecv(ctlrno, a, tftp, sizeof(Tftp))) < sizeof(tftp->header))
310 			continue;
311 
312 		switch((tftp->header[0]<<8)|tftp->header[1]){
313 
314 		case Tftp_ERROR:
315 			print("tftpopen: error (%d): %s\n",
316 				(tftp->header[2]<<8)|tftp->header[3], tftp->data);
317 			return -1;
318 
319 		case Tftp_DATA:
320 			tftpblockno = 1;
321 			len = (tftp->header[2]<<8)|tftp->header[3];
322 			if(len != tftpblockno){
323 				print("tftpopen: block error: %d\n", len);
324 				nak(ctlrno, a, 1, "block error", 0);
325 				return -1;
326 			}
327 			return rlen-sizeof(tftp->header);
328 		}
329 	}
330 
331 	print("tftpopen: failed to connect to server\n");
332 	return -1;
333 }
334 
335 static int
336 tftpread(int ctlrno, Netaddr *a, Tftp *tftp, int dlen)
337 {
338 	int blockno, len, retry;
339 	uchar buf[4];
340 
341 	buf[0] = 0;
342 	buf[1] = Tftp_ACK;
343 	buf[2] = tftpblockno>>8;
344 	buf[3] = tftpblockno;
345 	tftpblockno++;
346 
347 	dlen += sizeof(tftp->header);
348 
349 	retry = 0;
350 buggery:
351 	udpsend(ctlrno, a, buf, sizeof(buf));
352 
353 	if((len = udprecv(ctlrno, a, tftp, dlen)) < dlen){
354 		print("tftpread: %d != %d\n", len, dlen);
355 		nak(ctlrno, a, 2, "short read", 0);
356 		if(retry++ < 5)
357 			goto buggery;
358 		return -1;
359 	}
360 
361 	blockno = (tftp->header[2]<<8)|tftp->header[3];
362 	if(blockno != tftpblockno){
363 		print("?");
364 
365 		if(blockno == tftpblockno-1 && retry++ < 8)
366 			goto buggery;
367 		print("tftpread: block error: %d, expected %d\n", blockno, tftpblockno);
368 		nak(ctlrno, a, 1, "block error", 0);
369 
370 		return -1;
371 	}
372 
373 	return len-sizeof(tftp->header);
374 }
375 
376 int
377 bootp(int ctlrno, char *file)
378 {
379 	Bootp req, rep;
380 	int i, dlen, segsize, text, data, bss, total;
381 	uchar *ea, *addr, *p;
382 	ulong entry;
383 	Exec *exec;
384 	char name[128], *filename, *sysname;
385 
386 	if((ea = etheraddr(ctlrno)) == 0){
387 		print("invalid ctlrno %d\n", ctlrno);
388 		return -1;
389 	}
390 
391 	filename = 0;
392 	sysname = 0;
393 	if(file && *file){
394 		strcpy(name, file);
395 		if(filename = strchr(name, ':')){
396 			if(filename != name && *(filename-1) != '\\'){
397 				sysname = name;
398 				*filename++ = 0;
399 			}
400 		}
401 		else
402 			filename = name;
403 	}
404 
405 
406 	memset(&req, 0, sizeof(req));
407 	req.op = Bootrequest;
408 	req.htype = 1;			/* ethernet */
409 	req.hlen = Eaddrlen;		/* ethernet */
410 	memmove(req.chaddr, ea, Eaddrlen);
411 
412 	myaddr.ip = 0;
413 	myaddr.port = BPportsrc;
414 	memmove(myaddr.ea, ea, Eaddrlen);
415 
416 	for(i = 0; i < 10; i++) {
417 		server.ip = Bcastip;
418 		server.port = BPportdst;
419 		memmove(server.ea, broadcast, sizeof(server.ea));
420 		udpsend(ctlrno, &server, &req, sizeof(req));
421 		if(udprecv(ctlrno, &server, &rep, sizeof(rep)) <= 0)
422 			continue;
423 		if(memcmp(req.chaddr, rep.chaddr, Eaddrlen))
424 			continue;
425 		if(rep.htype != 1 || rep.hlen != Eaddrlen)
426 			continue;
427 		if(sysname == 0 || strcmp(sysname, rep.sname) == 0)
428 			break;
429 	}
430 	if(i >= 10) {
431 		print("bootp timed out\n");
432 		return -1;
433 	}
434 
435 	if(filename == 0 || *filename == 0)
436 		filename = rep.file;
437 
438 	if(rep.sname[0] != '\0')
439 		 print("%s ", rep.sname);
440 	print("(%d.%d.%d.%d!%d): %s\n",
441 		rep.siaddr[0],
442 		rep.siaddr[1],
443 		rep.siaddr[2],
444 		rep.siaddr[3],
445 		server.port,
446 		filename);uartwait();
447 
448 	myaddr.ip = nhgetl(rep.yiaddr);
449 	myaddr.port = tftpport++;
450 	server.ip = nhgetl(rep.siaddr);
451 	server.port = TFTPport;
452 
453 	if((dlen = tftpopen(ctlrno, &server, filename, &tftpb)) < 0)
454 		return -1;
455 	exec = (Exec*)(tftpb.data);
456 	if(dlen < sizeof(Exec) || GLLONG(exec->magic) != Q_MAGIC){
457 		nak(ctlrno, &server, 0, "bad magic number", 1);
458 		return -1;
459 	}
460 	text = GLLONG(exec->text);
461 	data = GLLONG(exec->data);
462 	bss = GLLONG(exec->bss);
463 	total = text+data+bss;
464 	entry = GLLONG(exec->entry);
465 print("load@%8.8lux: ", PADDR(entry));uartwait();
466 	print("%d", text);
467 
468 	addr = (uchar*)PADDR(entry);
469 	p = tftpb.data+sizeof(Exec);
470 	dlen -= sizeof(Exec);
471 	segsize = text;
472 	for(;;){
473 		if(dlen == 0){
474 			if((dlen = tftpread(ctlrno, &server, &tftpb, sizeof(tftpb.data))) < 0)
475 				return -1;
476 			p = tftpb.data;
477 		}
478 		if(segsize <= dlen)
479 			i = segsize;
480 		else
481 			i = dlen;
482 		memmove(addr, p, i);
483 
484 		addr += i;
485 		p += i;
486 		segsize -= i;
487 		dlen -= i;
488 
489 		if(segsize <= 0){
490 			if(data == 0)
491 				break;
492 			print("+%d", data);
493 			segsize = data;
494 			data = 0;
495 			addr = (uchar*)PGROUND((ulong)addr);
496 		}
497 	}
498 	nak(ctlrno, &server, 3, "ok", 0);		/* tftpclose */
499 	print("+%d=%d\n", bss, total);
500 	print("entry: 0x%lux\n", entry);
501 	uartwait();
502 	scc2stop();
503 	splhi();
504 	(*(void(*)(void))(PADDR(entry)))();
505 
506 	return 0;
507 }
508