1 #include <u.h>
2 #include <libc.h>
3 #include <ip.h>
4 #include <thread.h>
5 #include "netbios.h"
6
/*
 * State of the single shared NetBIOS session-service TCP listener.
 * thread is the proc id of tcplistener, or -1 when none is running; it is
 * the first member so that the { -1 } initializer sets it.
 * acfd/adir are the control fd and directory returned by announce().
 * thread and acfd are changed only with the embedded QLock held.
 */
static struct {
	int thread;
	QLock;
	char adir[NETPATHLEN];
	int acfd;
} tcp = { -1 };
15
typedef struct Session Session;

/* Lifecycle states of a Session (see tcpreader and nbssscatterread). */
enum { NeedSessionRequest, Connected, Dead };

/*
 * One NetBIOS session, incoming (createsession) or outgoing (nbssconnect).
 * NbSession must remain the first member: nbssfree and nbssscatterread
 * convert an NbSession* back to a Session* with a plain cast.
 */
struct Session {
	NbSession;
	int thread;		/* proc running tcpreader; set only by createsession */
	Session *next;		/* link in sessions.head */
	int state;		/* NeedSessionRequest, Connected or Dead */
	NBSSWRITEFN *write;	/* data callback installed by a Listen's accept function */
};

/* All live sessions; list updates are made with the embedded QLock held. */
static struct {
	QLock;
	Session *head;
} sessions;
32
typedef struct Listen Listen;

/*
 * A listener registered by nbsslisten.  An incoming session request is
 * offered to the first entry on listens.head (i.e. the most recently
 * registered) whose called name and calling name both match (nbnameequal).
 */
struct Listen {
	NbName to;	/* called name to match */
	NbName from;	/* calling name to match */
	/*
	 * Returns non-zero to accept s.  It is presumably expected to store the
	 * session's data callback through writep, since tcpreader calls it
	 * afterwards — confirm against callers.
	 */
	int (*accept)(void *magic, NbSession *s, NBSSWRITEFN **writep);
	void *magic;	/* opaque value passed back to accept */
	Listen *next;
};

/* Registered listeners; the list is walked/updated with the QLock held. */
static struct {
	QLock;
	Listen *head;
} listens;
47
/*
 * Close the session's connection, unlink it from the global session list
 * (if it is on it) and free it.  s must not be used after this returns.
 */
static void
deletesession(Session *s)
{
	Session *prev, *cur;

	close(s->fd);
	qlock(&sessions);
	prev = nil;
	for (cur = sessions.head; cur != nil; cur = cur->next) {
		if (cur == s) {
			if (prev == nil)
				sessions.head = cur->next;
			else
				prev->next = cur->next;
			break;
		}
		prev = cur;
	}
	qunlock(&sessions);
	free(s);
}
61
/*
 * Reject a session request with a NEGATIVE SESSION RESPONSE (type 0x83)
 * carrying error_code, using buf as scratch space.  The write result is
 * ignored because the caller tears the session down regardless.
 */
static void
negativeresponse(Session *s, uchar *buf, uchar error_code)
{
	buf[0] = 0x83;
	buf[1] = 0;
	hnputs(buf + 2, 1);
	buf[4] = error_code;
	write(s->fd, buf, 5);
}

/*
 * Handle a SESSION REQUEST (type 0x81) whose payload is buf[4..4+length).
 * Decodes the called and calling names into s->to and s->from, finds a
 * matching Listen entry and offers it the session.
 * Returns 0 once the session is Connected, -1 if the caller must tear it down.
 */
static int
sessionrequest(Session *s, uchar *buf, long length)
{
	uchar *p, *ep;
	Listen *l;
	int k, called_found;

	if (s->state == Connected) {
		print("nbss: unexpected session request\n");
		return -1;
	}
	p = buf + 4;
	ep = p + length;
	k = nbnamedecode(p, p, ep, s->to);
	if (k == 0) {
		print("nbss: malformed called name in session request\n");
		return -1;
	}
	p += k;
	k = nbnamedecode(p, p, ep, s->from);
	if (k == 0) {
		print("nbss: malformed calling name in session request\n");
		return -1;
	}
	/* Trailing bytes after the calling name are deliberately tolerated. */

	called_found = 0;
	qlock(&listens);
	for (l = listens.head; l != nil; l = l->next)
		if (nbnameequal(l->to, s->to)) {
			called_found = 1;
			if (nbnameequal(l->from, s->from))
				break;
		}
	if (l == nil) {
		qunlock(&listens);
		/* 0x80: nobody listens on the called name; 0x81: not for this caller */
		negativeresponse(s, buf, called_found ? 0x81 : 0x80);
		return -1;
	}
	if (!(*l->accept)(l->magic, s, &s->write)) {
		qunlock(&listens);
		negativeresponse(s, buf, 0x83);
		return -1;
	}
	/*
	 * The accept function has now been handed s, so mark it Connected
	 * before replying: if the positive response cannot be sent, tcpreader's
	 * cleanup then tells the handler (via write(s, nil, -1)) it is gone.
	 */
	s->state = Connected;
	buf[0] = 0x82;
	buf[1] = 0;
	hnputs(buf + 2, 0);
	k = write(s->fd, buf, 4);
	qunlock(&listens);
	return k == 4 ? 0 : -1;
}

/*
 * Per-connection reader proc for an incoming session (started by
 * createsession).  Reads session-service packets from s->fd: handles the
 * session request, passes session-message payloads to s->write and ignores
 * keepalives.  On EOF or any error it notifies a Connected handler with
 * (nil, -1) and deletes the session.
 */
static void
tcpreader(void *a)
{
	Session *s = a;
	uchar *buf;
	uchar flags;
	long length;	/* 17 bits wide: a ushort would drop the extension bit */

	/* 4-byte header plus the largest 17-bit payload */
	buf = nbemalloc(0x1ffff + 4);
	for (;;) {
		if (readn(s->fd, buf, 4) != 4)
			goto die;
		flags = buf[1];
		/* validate before trusting the length field */
		if (flags & 0xfe) {
			print("nbss: invalid flags field 0x%.2ux\n", flags);
			goto die;
		}
		length = nhgets(buf + 2) | ((flags & 1) << 16);
		if (readn(s->fd, buf + 4, length) != length)
			goto die;
		switch (buf[0]) {
		case 0:	/* session message */
			if (s->state != Connected && s->state != Dead) {
				print("nbss: unexpected session message\n");
				goto die;
			}
			if (s->state == Connected && (*s->write)(s, buf + 4, length) != 0) {
				s->state = Dead;
				goto die;
			}
			break;
		case 0x81:	/* session request */
			if (sessionrequest(s, buf, length) < 0)
				goto die;
			break;
		case 0x85:	/* keep awake */
			break;
		default:
			print("nbss: opcode 0x%.2ux unexpected\n", buf[0]);
			goto die;
		}
	}
die:
	free(buf);
	if (s->state == Connected)
		(*s->write)(s, nil, -1);
	deletesession(s);
}
180
/*
 * Wrap an accepted data fd in a new Session and start a tcpreader proc
 * for it.  Returns the session, or nil if the proc could not be started
 * (the fd is left open for the caller to close).
 */
static NbSession *
createsession(int fd)
{
	Session *ns;
	int tid;

	ns = nbemalloc(sizeof(Session));
	ns->fd = fd;
	ns->state = NeedSessionRequest;
	/*
	 * Hold the list lock across the fork: a reader that fails at once
	 * blocks in deletesession until ns has actually been linked in.
	 */
	qlock(&sessions);
	tid = procrfork(tcpreader, ns, 32768, RFNAMEG);
	if (tid < 0) {
		qunlock(&sessions);
		free(ns);
		return nil;
	}
	ns->thread = tid;
	ns->next = sessions.head;
	sessions.head = ns;
	qunlock(&sessions);
	return ns;
}
200
/*
 * Listener proc: accepts TCP calls on the announced netbios port and
 * creates a session for each.  When listen or accept fails it closes the
 * announce control fd and marks the listener as stopped (tcp.thread = -1)
 * so a later nbsslisten can restart it.
 */
static void
tcplistener(void *)
{
	char ldir[NETPATHLEN];
	int lcfd, dfd;

	for (;;) {
		lcfd = listen(tcp.adir, ldir);
		if (lcfd < 0)
			break;
		dfd = accept(lcfd, ldir);
		close(lcfd);
		if (dfd < 0)
			break;
		if (createsession(dfd) == nil)
			close(dfd);
	}
	qlock(&tcp);
	close(tcp.acfd);
	tcp.thread = -1;
	qunlock(&tcp);
}
227
/*
 * Register accept to be offered incoming sessions whose called name is to
 * and calling name is from; magic is passed back to accept unchanged.
 * Announces tcp!*!netbios and starts the listener proc on first use.
 * Returns 0 on success, -1 if the port cannot be announced or the
 * listener proc cannot be started.
 */
int
nbsslisten(NbName to, NbName from, int (*accept)(void *magic, NbSession *s, NBSSWRITEFN **writep), void *magic)
{
	Listen *l;

	qlock(&tcp);
	if (tcp.thread < 0) {
		fmtinstall('B', nbnamefmt);
		tcp.acfd = announce("tcp!*!netbios", tcp.adir);
		if (tcp.acfd < 0) {
			print("nbsslisten: can't announce: %r\n");
			qunlock(&tcp);
			return -1;
		}
		tcp.thread = proccreate(tcplistener, nil, 16384);
		if (tcp.thread < 0) {
			/* no listener will ever serve this announcement; undo it */
			print("nbsslisten: can't start listener: %r\n");
			close(tcp.acfd);
			tcp.acfd = -1;
			qunlock(&tcp);
			return -1;
		}
	}
	qunlock(&tcp);

	l = nbemalloc(sizeof(Listen));
	nbnamecpy(l->to, to);
	nbnamecpy(l->from, from);
	l->accept = accept;
	l->magic = magic;
	qlock(&listens);
	l->next = listens.head;
	listens.head = l;
	qunlock(&listens);
	return 0;
}
255
/*
 * Close and release a session obtained from nbssconnect or handed to an
 * accept function.
 */
void
nbssfree(NbSession *s)
{
	Session *sess = (Session *)s;

	deletesession(sess);
}
261
/*
 * Send one SESSION MESSAGE whose payload is the concatenation of the
 * buffers in a (terminated by an entry with p == nil).
 * Returns 0 on success, -1 on a write error or if the total payload does
 * not fit the 17-bit session-service length field.
 */
int
nbssgatherwrite(NbSession *s, NbScatterGather *a)
{
	uchar hdr[4];
	NbScatterGather *ap;
	long l;

	l = 0;
	for (ap = a; ap->p; ap++)
		l += ap->l;
	if (l < 0 || l > 0x1ffff) {
		werrstr("message too big (%ld)", l);
		return -1;
	}
	/*
	 * hnputl stores l big-endian: hdr[0] (type) becomes 0 = session
	 * message, hdr[1] (flags) receives bit 16 of l as the extension bit,
	 * hdr[2..3] the low 16 bits.  This only holds because l <= 0x1ffff.
	 */
	hnputl(hdr, l);
	if (write(s->fd, hdr, sizeof(hdr)) != sizeof(hdr))
		return -1;
	for (ap = a; ap->p; ap++)
		if (write(s->fd, ap->p, ap->l) != ap->l)
			return -1;
	return 0;
}
282
/*
 * Open an outgoing NetBIOS session to the name to, calling as from.
 * Resolves to, dials its netbios TCP port, sends a SESSION REQUEST (0x81)
 * and reads the response.  Returns a Connected session on a positive
 * response (0x82); otherwise returns nil, with the error string set for
 * protocol-level failures.
 */
NbSession *
nbssconnect(NbName to, NbName from)
{
	Session *s;
	uchar ipaddr[IPaddrlen];
	char dialaddress[100];
	char dir[NETPATHLEN];
	uchar msg[576];
	uchar *ep;
	int fd, k;
	long o;
	uchar flags;
	long length;

	if (!nbnameresolve(to, ipaddr))
		return nil;
	fmtinstall('I', eipfmt);
	snprint(dialaddress, sizeof(dialaddress), "tcp!%I!netbios", ipaddr);
	fd = dial(dialaddress, nil, dir, nil);
	if (fd < 0)
		return nil;

	/* SESSION REQUEST: 4-byte header, encoded called name, encoded calling name */
	ep = msg + sizeof(msg);	/* end of the whole buffer */
	msg[0] = 0x81;
	msg[1] = 0;
	o = 4;
	/* NOTE(review): assumes nbnameencode returns <= 0 when it cannot encode — confirm */
	k = nbnameencode(msg + o, ep, to);
	if (k <= 0) {
		close(fd);
		werrstr("can't encode called name");
		return nil;
	}
	o += k;
	k = nbnameencode(msg + o, ep, from);
	if (k <= 0) {
		close(fd);
		werrstr("can't encode calling name");
		return nil;
	}
	o += k;
	hnputs(msg + 2, o - 4);
	if (write(fd, msg, o) != o) {
		close(fd);
		return nil;
	}

	if (readn(fd, msg, 4) != 4) {
		close(fd);
		return nil;
	}
	flags = msg[1];
	length = nhgets(msg + 2) | ((flags & 1) << 16);
	switch (msg[0]) {
	default:
		close(fd);
		werrstr("unexpected session message code 0x%.2ux", msg[0]);
		return nil;
	case 0x82:
		if (length != 0) {
			close(fd);
			werrstr("length not 0 in positive session response");
			return nil;
		}
		break;
	case 0x83:
		if (length != 1) {
			close(fd);
			werrstr("length not 1 in negative session response");
			return nil;
		}
		if (readn(fd, msg + 4, 1) != 1) {
			close(fd);
			return nil;
		}
		close(fd);
		werrstr("negative session response 0x%.2ux", msg[4]);
		return nil;
	}

	s = nbemalloc(sizeof(Session));
	s->fd = fd;
	/* record the endpoints, as tcpreader does for incoming sessions */
	nbnamecpy(s->to, to);
	nbnamecpy(s->from, from);
	s->state = Connected;
	qlock(&sessions);
	s->next = sessions.head;
	sessions.head = s;
	qunlock(&sessions);
	return s;
}
354
/*
 * Read one SESSION MESSAGE into the buffers of a (terminated by an entry
 * with p == nil), filling them in order and skipping keepalives.
 * Returns the payload length, or -1 (marking the session Dead) on EOF,
 * a protocol error, or a message larger than the buffers' total size.
 */
long
nbssscatterread(NbSession *nbs, NbScatterGather *a)
{
	Session *s = (Session *)nbs;
	NbScatterGather *ap;
	uchar hdr[4];
	long room, length, remaining, chunk;

	room = 0;
	for (ap = a; ap->p != nil; ap++)
		room += ap->l;

	/* consume keepalives until a session-message header arrives */
	for (;;) {
		if (readn(s->fd, hdr, 4) != 4)
			goto dead;
		length = nhgets(hdr + 2) | ((hdr[1] & 1) << 16);
		if (hdr[0] == 0x00)
			break;
		if (hdr[0] != 0x85) {
			werrstr("unexpected session message code 0x%.2ux", hdr[0]);
			goto dead;
		}
		if (length != 0) {
			werrstr("length in keepalive not 0");
			goto dead;
		}
	}

	if (length > room) {
		werrstr("message too big (%ld)", length);
		goto dead;
	}
	remaining = length;
	for (ap = a; remaining != 0 && ap->p != nil; ap++) {
		chunk = remaining < ap->l ? remaining : ap->l;
		if (readn(s->fd, ap->p, chunk) != chunk)
			goto dead;
		remaining -= chunk;
	}
	return length;

dead:
	s->state = Dead;
	return -1;
}
409
/*
 * Send maxlen bytes from buf as a single session message.
 * Returns 0 on success, -1 on failure (see nbssgatherwrite).
 */
int
nbsswrite(NbSession *s, void *buf, long maxlen)
{
	NbScatterGather sg[2];

	sg[0].p = buf;
	sg[0].l = maxlen;
	sg[1].p = nil;	/* terminator */
	return nbssgatherwrite(s, sg);
}
419
/*
 * Read one session message of at most maxlen bytes into buf.
 * Returns its length, or -1 on failure (see nbssscatterread).
 */
long
nbssread(NbSession *s, void *buf, long maxlen)
{
	NbScatterGather sg[2];

	sg[0].p = buf;
	sg[0].l = maxlen;
	sg[1].p = nil;	/* terminator */
	return nbssscatterread(s, sg);
}
429