xref: /plan9/sys/src/cmd/aquarela/nbss.c (revision 8ccd4a6360d974db7bd7bbd4f37e7018419ea908)
1 #include <u.h>
2 #include <libc.h>
3 #include <ip.h>
4 #include <thread.h>
5 #include "netbios.h"
6 
7 static struct {
8 	int thread;
9 	QLock;
10 	char adir[NETPATHLEN];
11 	int acfd;
12 	char ldir[NETPATHLEN];
13 	int lcfd;
14 } tcp = { -1 };
15 
16 typedef struct Session Session;
17 
18 enum { NeedSessionRequest, Connected, Dead };
19 
20 struct Session {
21 	NbSession;
22 	int thread;
23 	Session *next;
24 	int state;
25 	NBSSWRITEFN *write;
26 };
27 
28 static struct {
29 	QLock;
30 	Session *head;
31 } sessions;
32 
33 typedef struct Listen Listen;
34 
35 struct Listen {
36 	NbName to;
37 	NbName from;
38 	int (*accept)(void *magic, NbSession *s, NBSSWRITEFN **writep);
39 	void *magic;
40 	Listen *next;
41 };
42 
43 static struct {
44 	QLock;
45 	Listen *head;
46 } listens;
47 
48 static void
deletesession(Session * s)49 deletesession(Session *s)
50 {
51 	Session **sp;
52 	close(s->fd);
53 	qlock(&sessions);
54 	for (sp = &sessions.head; *sp && *sp != s; sp = &(*sp)->next)
55 		;
56 	if (*sp)
57 		*sp = s->next;
58 	qunlock(&sessions);
59 	free(s);
60 }
61 
62 static void
tcpreader(void * a)63 tcpreader(void *a)
64 {
65 	Session *s = a;
66 	uchar *buf;
67 	int buflen = 0x1ffff + 4;
68 	buf = nbemalloc(buflen);
69 	for (;;) {
70 		int n;
71 		uchar flags;
72 		ushort length;
73 
74 		n = readn(s->fd, buf, 4);
75 		if (n != 4) {
76 		die:
77 			free(buf);
78 			if (s->state == Connected)
79 				(*s->write)(s, nil, -1);
80 			deletesession(s);
81 			return;
82 		}
83 		flags = buf[1];
84 		length = nhgets(buf + 2) | ((flags & 1) << 16);
85 		n = readn(s->fd, buf + 4, length);
86 		if (n != length)
87 			goto die;
88 		if (flags & 0xfe) {
89 			print("nbss: invalid flags field 0x%.2ux\n", flags);
90 			goto die;
91 		}
92 		switch (buf[0]) {
93 		case 0: /* session message */
94 			if (s->state != Connected && s->state != Dead) {
95 				print("nbss: unexpected session message\n");
96 				goto die;
97 			}
98 			if (s->state == Connected) {
99 				if ((*s->write)(s, buf + 4, length) != 0) {
100 					s->state = Dead;
101 					goto die;
102 				}
103 			}
104 			break;
105 		case 0x81: /* session request */ {
106 			uchar *p, *ep;
107 			Listen *l;
108 			int k;
109 			int called_found;
110 			uchar error_code;
111 
112 			if (s->state == Connected) {
113 				print("nbss: unexpected session request\n");
114 				goto die;
115 			}
116 			p = buf + 4;
117 			ep = p + length;
118 			k = nbnamedecode(p, p, ep, s->to);
119 			if (k == 0) {
120 				print("nbss: malformed called name in session request\n");
121 				goto die;
122 			}
123 			p += k;
124 			k = nbnamedecode(p, p, ep, s->from);
125 			if (k == 0) {
126 				print("nbss: malformed calling name in session request\n");
127 				goto die;
128 			}
129 /*
130 			p += k;
131 			if (p != ep) {
132 				print("nbss: extra data at end of session request\n");
133 				goto die;
134 			}
135 */
136 			called_found = 0;
137 //print("nbss: called %B calling %B\n", s->to, s->from);
138 			qlock(&listens);
139 			for (l = listens.head; l; l = l->next)
140 				if (nbnameequal(l->to, s->to)) {
141 					called_found = 1;
142 					if (nbnameequal(l->from, s->from))
143 						break;
144 				}
145 			if (l == nil) {
146 				qunlock(&listens);
147 				error_code  = called_found ? 0x81 : 0x80;
148 			replydie:
149 				buf[0] = 0x83;
150 				buf[1] = 0;
151 				hnputs(buf + 2, 1);
152 				buf[4] = error_code;
153 				write(s->fd, buf, 5);
154 				goto die;
155 			}
156 			if (!(*l->accept)(l->magic, s, &s->write)) {
157 				qunlock(&listens);
158 				error_code = 0x83;
159 				goto replydie;
160 			}
161 			buf[0] = 0x82;
162 			buf[1] = 0;
163 			hnputs(buf + 2, 0);
164 			if (write(s->fd, buf, 4) != 4) {
165 				qunlock(&listens);
166 				goto die;
167 			}
168 			s->state = Connected;
169 			qunlock(&listens);
170 			break;
171 		}
172 		case 0x85: /* keep awake */
173 			break;
174 		default:
175 			print("nbss: opcode 0x%.2ux unexpected\n", buf[0]);
176 			goto die;
177 		}
178 	}
179 }
180 
181 static NbSession *
createsession(int fd)182 createsession(int fd)
183 {
184 	Session *s;
185 	s = nbemalloc(sizeof(Session));
186 	s->fd = fd;
187 	s->state = NeedSessionRequest;
188 	qlock(&sessions);
189 	s->thread = procrfork(tcpreader, s, 32768, RFNAMEG);
190 	if (s->thread < 0) {
191 		qunlock(&sessions);
192 		free(s);
193 		return nil;
194 	}
195 	s->next = sessions.head;
196 	sessions.head = s;
197 	qunlock(&sessions);
198 	return s;
199 }
200 
201 static void
tcplistener(void *)202 tcplistener(void *)
203 {
204 	for (;;) {
205 		int dfd;
206 		char ldir[NETPATHLEN];
207 		int lcfd;
208 //print("tcplistener: listening\n");
209 		lcfd = listen(tcp.adir, ldir);
210 //print("tcplistener: contact\n");
211 		if (lcfd < 0) {
212 		die:
213 			qlock(&tcp);
214 			close(tcp.acfd);
215 			tcp.thread = -1;
216 			qunlock(&tcp);
217 			return;
218 		}
219 		dfd = accept(lcfd, ldir);
220 		close(lcfd);
221 		if (dfd < 0)
222 			goto die;
223 		if (createsession(dfd) == nil)
224 			close(dfd);
225 	}
226 }
227 
228 int
nbsslisten(NbName to,NbName from,int (* accept)(void * magic,NbSession * s,NBSSWRITEFN ** writep),void * magic)229 nbsslisten(NbName to, NbName from,int (*accept)(void *magic, NbSession *s, NBSSWRITEFN **writep), void *magic)
230 {
231 	Listen *l;
232 	qlock(&tcp);
233 	if (tcp.thread < 0) {
234 		fmtinstall('B', nbnamefmt);
235 		tcp.acfd = announce("tcp!*!netbios", tcp.adir);
236 		if (tcp.acfd < 0) {
237 			print("nbsslisten: can't announce: %r\n");
238 			qunlock(&tcp);
239 			return -1;
240 		}
241 		tcp.thread = proccreate(tcplistener, nil, 16384);
242 	}
243 	qunlock(&tcp);
244 	l = nbemalloc(sizeof(Listen));
245 	nbnamecpy(l->to, to);
246 	nbnamecpy(l->from, from);
247 	l->accept = accept;
248 	l->magic = magic;
249 	qlock(&listens);
250 	l->next = listens.head;
251 	listens.head = l;
252 	qunlock(&listens);
253 	return 0;
254 }
255 
256 void
nbssfree(NbSession * s)257 nbssfree(NbSession *s)
258 {
259 	deletesession((Session *)s);
260 }
261 
262 int
nbssgatherwrite(NbSession * s,NbScatterGather * a)263 nbssgatherwrite(NbSession *s, NbScatterGather *a)
264 {
265 	uchar hdr[4];
266 	NbScatterGather *ap;
267 	long l = 0;
268 	for (ap = a; ap->p; ap++)
269 		l += ap->l;
270 //print("nbssgatherwrite %ld bytes\n", l);
271 	hnputl(hdr, l);
272 //nbdumpdata(hdr, sizeof(hdr));
273 	if (write(s->fd, hdr, sizeof(hdr)) != sizeof(hdr))
274 		return -1;
275 	for (ap = a; ap->p; ap++) {
276 //nbdumpdata(ap->p, ap->l);
277 		if (write(s->fd, ap->p, ap->l) != ap->l)
278 			return -1;
279 	}
280 	return 0;
281 }
282 
283 NbSession *
nbssconnect(NbName to,NbName from)284 nbssconnect(NbName to, NbName from)
285 {
286 	Session *s;
287 	uchar ipaddr[IPaddrlen];
288 	char dialaddress[100];
289 	char dir[NETPATHLEN];
290 	uchar msg[576];
291 	int fd;
292 	long o;
293 	uchar flags;
294 	long length;
295 
296 	if (!nbnameresolve(to, ipaddr))
297 		return nil;
298 	fmtinstall('I', eipfmt);
299 	snprint(dialaddress, sizeof(dialaddress), "tcp!%I!netbios", ipaddr);
300 	fd = dial(dialaddress, nil, dir, nil);
301 	if (fd < 0)
302 		return nil;
303 	msg[0] = 0x81;
304 	msg[1] = 0;
305 	o = 4;
306 	o += nbnameencode(msg + o, msg + sizeof(msg) - o, to);
307 	o += nbnameencode(msg + o, msg + sizeof(msg) - o, from);
308 	hnputs(msg + 2, o - 4);
309 	if (write(fd, msg, o) != o) {
310 		close(fd);
311 		return nil;
312 	}
313 	if (readn(fd, msg, 4) != 4) {
314 		close(fd);
315 		return nil;
316 	}
317 	flags = msg[1];
318 	length = nhgets(msg + 2) | ((flags & 1) << 16);
319 	switch (msg[0]) {
320 	default:
321 		close(fd);
322 		werrstr("unexpected session message code 0x%.2ux", msg[0]);
323 		return nil;
324 	case 0x82:
325 		if (length != 0) {
326 			close(fd);
327 			werrstr("length not 0 in positive session response");
328 			return nil;
329 		}
330 		break;
331 	case 0x83:
332 		if (length != 1) {
333 			close(fd);
334 			werrstr("length not 1 in negative session response");
335 			return nil;
336 		}
337 		if (readn(fd, msg + 4, 1) != 1) {
338 			close(fd);
339 			return nil;
340 		}
341 		close(fd);
342 		werrstr("negative session response 0x%.2ux", msg[4]);
343 		return nil;
344 	}
345 	s = nbemalloc(sizeof(Session));
346 	s->fd = fd;
347 	s->state = Connected;
348 	qlock(&sessions);
349 	s->next = sessions.head;
350 	sessions.head = s;
351 	qunlock(&sessions);
352 	return s;
353 }
354 
355 long
nbssscatterread(NbSession * nbs,NbScatterGather * a)356 nbssscatterread(NbSession *nbs, NbScatterGather *a)
357 {
358 	uchar hdr[4];
359 	uchar flags;
360 	long length, total;
361 	NbScatterGather *ap;
362 	Session *s = (Session *)nbs;
363 
364 	long l = 0;
365 	for (ap = a; ap->p; ap++)
366 		l += ap->l;
367 //print("nbssscatterread %ld bytes\n", l);
368 again:
369 	if (readn(s->fd, hdr, 4) != 4) {
370 	dead:
371 		s->state = Dead;
372 		return -1;
373 	}
374 	flags = hdr[1];
375 	length = nhgets(hdr + 2) | ((flags & 1) << 16);
376 //print("%.2ux: %d\n", hdr[0], length);
377 	switch (hdr[0]) {
378 	case 0x85:
379 		if (length != 0) {
380 			werrstr("length in keepalive not 0");
381 			goto dead;
382 		}
383 		goto again;
384 	case 0x00:
385 		break;
386 	default:
387 		werrstr("unexpected session message code 0x%.2ux", hdr[0]);
388 		goto dead;
389 	}
390 	if (length > l) {
391 		werrstr("message too big (%ld)", length);
392 		goto dead;
393 	}
394 	total = length;
395 	for (ap = a; length && ap->p; ap++) {
396 		long thistime;
397 		long n;
398 		thistime = length;
399 		if (thistime > ap->l)
400 			thistime = ap->l;
401 //print("reading %d\n", length);
402 		n = readn(s->fd, ap->p, thistime);
403 		if (n != thistime)
404 			goto dead;
405 		length -= thistime;
406 	}
407 	return total;
408 }
409 
410 int
nbsswrite(NbSession * s,void * buf,long maxlen)411 nbsswrite(NbSession *s, void *buf, long maxlen)
412 {
413 	NbScatterGather a[2];
414 	a[0].l = maxlen;
415 	a[0].p = buf;
416 	a[1].p = nil;
417 	return nbssgatherwrite(s, a);
418 }
419 
420 long
nbssread(NbSession * s,void * buf,long maxlen)421 nbssread(NbSession *s, void *buf, long maxlen)
422 {
423 	NbScatterGather a[2];
424 	a[0].l = maxlen;
425 	a[0].p = buf;
426 	a[1].p = nil;
427 	return nbssscatterread(s, a);
428 }
429