xref: /plan9/sys/src/cmd/ndb/dnstcp.c (revision e02f7f02bd837384880240d3ead0b1c0930eacbe)
/*
 * dnstcp - serve dns via tcp, including zone transfers (axfr)
 */
4 #include <u.h>
5 #include <libc.h>
6 #include <ip.h>
7 #include "dns.h"
8 
9 Cfg cfg;
10 
11 char	*caller = "";
12 char	*dbfile;
13 int	debug;
14 uchar	ipaddr[IPaddrlen];	/* my ip address */
15 char	*logfile = "dns";
16 int	maxage = 60*60;
17 char	mntpt[Maxpath];
18 int	needrefresh;
19 ulong	now;
20 vlong	nowns;
21 int	testing;
22 int	traceactivity;
23 char	*zonerefreshprogram;
24 
25 static int	readmsg(int, uchar*, int);
26 static void	reply(int, DNSmsg*, Request*);
27 static void	dnzone(DNSmsg*, DNSmsg*, Request*);
28 static void	getcaller(char*);
29 static void	refreshmain(char*);
30 
31 void
usage(void)32 usage(void)
33 {
34 	fprint(2, "usage: %s [-rR] [-f ndb-file] [-x netmtpt] [conndir]\n", argv0);
35 	exits("usage");
36 }
37 
38 void
main(int argc,char * argv[])39 main(int argc, char *argv[])
40 {
41 	volatile int len, rcode;
42 	volatile char tname[32];
43 	char *volatile err, *volatile ext = "";
44 	volatile uchar buf[64*1024], callip[IPaddrlen];
45 	volatile DNSmsg reqmsg, repmsg;
46 	volatile Request req;
47 
48 	alarm(2*60*1000);
49 	cfg.cachedb = 1;
50 	ARGBEGIN{
51 	case 'd':
52 		debug++;
53 		break;
54 	case 'f':
55 		dbfile = EARGF(usage());
56 		break;
57 	case 'r':
58 		cfg.resolver = 1;
59 		break;
60 	case 'R':
61 		norecursion = 1;
62 		break;
63 	case 'x':
64 		ext = EARGF(usage());
65 		break;
66 	default:
67 		usage();
68 		break;
69 	}ARGEND
70 
71 	if(debug < 2)
72 		debug = 0;
73 
74 	if(argc > 0)
75 		getcaller(argv[0]);
76 
77 	cfg.inside = 1;
78 	dninit();
79 
80 	snprint(mntpt, sizeof mntpt, "/net%s", ext);
81 	if(myipaddr(ipaddr, mntpt) < 0)
82 		sysfatal("can't read my ip address");
83 	dnslog("dnstcp call from %s to %I", caller, ipaddr);
84 	memset(callip, 0, sizeof callip);
85 	parseip(callip, caller);
86 
87 	db2cache(1);
88 
89 	memset(&req, 0, sizeof req);
90 	setjmp(req.mret);
91 	req.isslave = 0;
92 	procsetname("main loop");
93 
94 	/* loop on requests */
95 	for(;; putactivity(0)){
96 		now = time(nil);
97 		memset(&repmsg, 0, sizeof repmsg);
98 		len = readmsg(0, buf, sizeof buf);
99 		if(len <= 0)
100 			break;
101 
102 		getactivity(&req, 0);
103 		req.aborttime = timems() + S2MS(15*Min);
104 		rcode = 0;
105 		memset(&reqmsg, 0, sizeof reqmsg);
106 		err = convM2DNS(buf, len, &reqmsg, &rcode);
107 		if(err){
108 			dnslog("server: input error: %s from %s", err, caller);
109 			free(err);
110 			break;
111 		}
112 		if (rcode == 0)
113 			if(reqmsg.qdcount < 1){
114 				dnslog("server: no questions from %s", caller);
115 				break;
116 			} else if(reqmsg.flags & Fresp){
117 				dnslog("server: reply not request from %s",
118 					caller);
119 				break;
120 			} else if((reqmsg.flags & Omask) != Oquery){
121 				dnslog("server: op %d from %s",
122 					reqmsg.flags & Omask, caller);
123 				break;
124 			}
125 		if(debug)
126 			dnslog("[%d] %d: serve (%s) %d %s %s",
127 				getpid(), req.id, caller,
128 				reqmsg.id, reqmsg.qd->owner->name,
129 				rrname(reqmsg.qd->type, tname, sizeof tname));
130 
131 		/* loop through each question */
132 		while(reqmsg.qd)
133 			if(reqmsg.qd->type == Taxfr)
134 				dnzone(&reqmsg, &repmsg, &req);
135 			else {
136 				dnserver(&reqmsg, &repmsg, &req, callip, rcode);
137 				reply(1, &repmsg, &req);
138 				rrfreelist(repmsg.qd);
139 				rrfreelist(repmsg.an);
140 				rrfreelist(repmsg.ns);
141 				rrfreelist(repmsg.ar);
142 			}
143 		rrfreelist(reqmsg.qd);		/* qd will be nil */
144 		rrfreelist(reqmsg.an);
145 		rrfreelist(reqmsg.ns);
146 		rrfreelist(reqmsg.ar);
147 
148 		if(req.isslave){
149 			putactivity(0);
150 			_exits(0);
151 		}
152 	}
153 	refreshmain(mntpt);
154 }
155 
156 static int
readmsg(int fd,uchar * buf,int max)157 readmsg(int fd, uchar *buf, int max)
158 {
159 	int n;
160 	uchar x[2];
161 
162 	if(readn(fd, x, 2) != 2)
163 		return -1;
164 	n = x[0]<<8 | x[1];
165 	if(n > max)
166 		return -1;
167 	if(readn(fd, buf, n) != n)
168 		return -1;
169 	return n;
170 }
171 
172 static void
reply(int fd,DNSmsg * rep,Request * req)173 reply(int fd, DNSmsg *rep, Request *req)
174 {
175 	int len, rv;
176 	char tname[32];
177 	uchar buf[64*1024];
178 	RR *rp;
179 
180 	if(debug){
181 		dnslog("%d: reply (%s) %s %s %ux",
182 			req->id, caller,
183 			rep->qd->owner->name,
184 			rrname(rep->qd->type, tname, sizeof tname),
185 			rep->flags);
186 		for(rp = rep->an; rp; rp = rp->next)
187 			dnslog("an %R", rp);
188 		for(rp = rep->ns; rp; rp = rp->next)
189 			dnslog("ns %R", rp);
190 		for(rp = rep->ar; rp; rp = rp->next)
191 			dnslog("ar %R", rp);
192 	}
193 
194 
195 	len = convDNS2M(rep, buf+2, sizeof(buf) - 2);
196 	buf[0] = len>>8;
197 	buf[1] = len;
198 	rv = write(fd, buf, len+2);
199 	if(rv != len+2){
200 		dnslog("[%d] sending reply: %d instead of %d", getpid(), rv,
201 			len+2);
202 		exits(0);
203 	}
204 }
205 
206 /*
207  *  Hash table for domain names.  The hash is based only on the
208  *  first element of the domain name.
209  */
210 extern DN	*ht[HTLEN];
211 
/*
 *  count the dot-separated elements of a domain name: one more than
 *  the number of '.' characters (so "" and "com" both count as 1).
 */
static int
numelem(char *name)
{
	int count = 1;

	while(*name != '\0')
		if(*name++ == '.')
			count++;
	return count;
}
223 
224 int
inzone(DN * dp,char * name,int namelen,int depth)225 inzone(DN *dp, char *name, int namelen, int depth)
226 {
227 	int n;
228 
229 	if(dp->name == nil)
230 		return 0;
231 	if(numelem(dp->name) != depth)
232 		return 0;
233 	n = strlen(dp->name);
234 	if(n < namelen)
235 		return 0;
236 	if(strcmp(name, dp->name + n - namelen) != 0)
237 		return 0;
238 	if(n > namelen && dp->name[n - namelen - 1] != '.')
239 		return 0;
240 	return 1;
241 }
242 
243 static void
dnzone(DNSmsg * reqp,DNSmsg * repp,Request * req)244 dnzone(DNSmsg *reqp, DNSmsg *repp, Request *req)
245 {
246 	DN *dp, *ndp;
247 	RR r, *rp;
248 	int h, depth, found, nlen;
249 
250 	memset(repp, 0, sizeof(*repp));
251 	repp->id = reqp->id;
252 	repp->qd = reqp->qd;
253 	reqp->qd = reqp->qd->next;
254 	repp->qd->next = 0;
255 	repp->flags = Fauth | Fresp | Oquery;
256 	if(!norecursion)
257 		repp->flags |= Fcanrec;
258 	dp = repp->qd->owner;
259 
260 	/* send the soa */
261 	repp->an = rrlookup(dp, Tsoa, NOneg);
262 	reply(1, repp, req);
263 	if(repp->an == 0)
264 		goto out;
265 	rrfreelist(repp->an);
266 	repp->an = nil;
267 
268 	nlen = strlen(dp->name);
269 
270 	/* construct a breadth-first search of the name space (hard with a hash) */
271 	repp->an = &r;
272 	for(depth = numelem(dp->name); ; depth++){
273 		found = 0;
274 		for(h = 0; h < HTLEN; h++)
275 			for(ndp = ht[h]; ndp; ndp = ndp->next)
276 				if(inzone(ndp, dp->name, nlen, depth)){
277 					for(rp = ndp->rr; rp; rp = rp->next){
278 						/*
279 						 * there shouldn't be negatives,
280 						 * but just in case.
281 						 * don't send any soa's,
282 						 * ns's are enough.
283 						 */
284 						if (rp->negative ||
285 						    rp->type == Tsoa)
286 							continue;
287 						r = *rp;
288 						r.next = 0;
289 						reply(1, repp, req);
290 					}
291 					found = 1;
292 				}
293 		if(!found)
294 			break;
295 	}
296 
297 	/* resend the soa */
298 	repp->an = rrlookup(dp, Tsoa, NOneg);
299 	reply(1, repp, req);
300 	rrfreelist(repp->an);
301 	repp->an = nil;
302 out:
303 	rrfree(repp->qd);
304 	repp->qd = nil;
305 }
306 
307 static void
getcaller(char * dir)308 getcaller(char *dir)
309 {
310 	int fd, n;
311 	static char remote[128];
312 
313 	snprint(remote, sizeof(remote), "%s/remote", dir);
314 	fd = open(remote, OREAD);
315 	if(fd < 0)
316 		return;
317 	n = read(fd, remote, sizeof remote - 1);
318 	close(fd);
319 	if(n <= 0)
320 		return;
321 	if(remote[n-1] == '\n')
322 		n--;
323 	remote[n] = 0;
324 	caller = remote;
325 }
326 
327 static void
refreshmain(char * net)328 refreshmain(char *net)
329 {
330 	int fd;
331 	char file[128];
332 
333 	snprint(file, sizeof(file), "%s/dns", net);
334 	if(debug)
335 		dnslog("refreshing %s", file);
336 	fd = open(file, ORDWR);
337 	if(fd < 0)
338 		dnslog("can't refresh %s", file);
339 	else {
340 		fprint(fd, "refresh");
341 		close(fd);
342 	}
343 }
344 
345 /*
346  *  the following varies between dnsdebug and dns
347  */
348 void
logreply(int id,uchar * addr,DNSmsg * mp)349 logreply(int id, uchar *addr, DNSmsg *mp)
350 {
351 	RR *rp;
352 
353 	dnslog("%d: rcvd %I flags:%s%s%s%s%s", id, addr,
354 		mp->flags & Fauth? " auth": "",
355 		mp->flags & Ftrunc? " trunc": "",
356 		mp->flags & Frecurse? " rd": "",
357 		mp->flags & Fcanrec? " ra": "",
358 		(mp->flags & (Fauth|Rmask)) == (Fauth|Rname)? " nx": "");
359 	for(rp = mp->qd; rp != nil; rp = rp->next)
360 		dnslog("%d: rcvd %I qd %s", id, addr, rp->owner->name);
361 	for(rp = mp->an; rp != nil; rp = rp->next)
362 		dnslog("%d: rcvd %I an %R", id, addr, rp);
363 	for(rp = mp->ns; rp != nil; rp = rp->next)
364 		dnslog("%d: rcvd %I ns %R", id, addr, rp);
365 	for(rp = mp->ar; rp != nil; rp = rp->next)
366 		dnslog("%d: rcvd %I ar %R", id, addr, rp);
367 }
368 
369 void
logsend(int id,int subid,uchar * addr,char * sname,char * rname,int type)370 logsend(int id, int subid, uchar *addr, char *sname, char *rname, int type)
371 {
372 	char buf[12];
373 
374 	dnslog("%d.%d: sending to %I/%s %s %s",
375 		id, subid, addr, sname, rname, rrname(type, buf, sizeof buf));
376 }
377 
378 RR*
getdnsservers(int class)379 getdnsservers(int class)
380 {
381 	return dnsservers(class);
382 }
383