# xref: /netbsd-src/external/mpl/bind/dist/bin/tests/system/cookie/ans9/ans.py (revision dd75ac5b443e967e26b4d18cc8cd5eb98512bfbf)
# Copyright (C) Internet Systems Consortium, Inc. ("ISC")
#
# SPDX-License-Identifier: MPL-2.0
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0.  If a copy of the MPL was not distributed with this
# file, you can obtain one at https://mozilla.org/MPL/2.0/.
#
# See the COPYRIGHT file distributed with this work for additional
# information regarding copyright ownership.
11
from __future__ import print_function
import os
import sys
import signal
import socket
import select
import struct
from datetime import datetime, timedelta
import time
import functools

import dns
import dns.edns
import dns.flags
import dns.message
import dns.query
import dns.rrset
import dns.tsig
import dns.tsigkeyring
import dns.version

from dns.edns import *
from dns.name import *
from dns.rcode import *
from dns.rdataclass import *
from dns.rdatatype import *
from dns.tsig import *
37
# Log query to file
def logquery(type, qname):
    """Append one "<type> <qname>" line to the file ``qlog``.

    type:  query type text (the parameter name shadows the builtin ``type``
           but is kept for caller compatibility).
    qname: query name text.
    """
    with open("qlog", "a") as f:
        # The format operator was missing: write() takes a single string.
        f.write("%s %s\n" % (type, qname))
42
43
# dnspython 2.x keyrings may carry the algorithm: each value is unpacked as
# an (algorithm, secret) pair.  That pair must be a tuple -- a set literal
# has no defined order, so algorithm and secret could be swapped.  Older
# dnspython accepts only a bare secret, so fall back to that form if the
# tuple form is rejected.
try:
    keyring = dns.tsigkeyring.from_text(
        {
            "foo": ("hmac-sha256", "aaaaaaaaaaaa"),
            "fake": ("hmac-sha256", "aaaaaaaaaaaa"),
        }
    )
except Exception:
    keyring = dns.tsigkeyring.from_text({"foo": "aaaaaaaaaaaa", "fake": "aaaaaaaaaaaa"})

# Set to True by create_response() when a spoofed "withtsig" answer was
# produced; the main loop then sends a second reply for the same query.
dopass2 = False
56
############################################################################
#
# This server serves both genuine and spoofed answers.  A spoofed answer is
# one that also includes the address 10.53.0.10.
#
# Queries received over UDP:
#
#   "nocookie"/A  -> spoofed answer, no cookie returned.
#   "tcponly"/A   -> spoofed answer, no cookie returned.
#   "withtsig"/A  -> two responses: first a spoofed answer TSIG-signed with
#                    the "fake" key, then a genuine answer with a cookie.
#   anything else -> genuine answer with a cookie.
#
# Queries received over TCP:
#
#   "nocookie"/A  -> genuine answer, no cookie returned.
#   anything else -> genuine answer with a cookie.
#
# The cookie (and the "withtsig" TSIG signature) is only added when the
# query itself carried a COOKIE option (EDNS option code 10).  An A query
# whose first label is "ns", received on the second address (10.53.0.10),
# is answered with 10.53.0.10 instead of 10.53.0.9.
#
############################################################################
def create_response(msg, tcp, first, ns10):
    """Parse one wire-format query and build the reply for it.

    msg:   raw query bytes, parsed with the module-level ``keyring``.
    tcp:   True if the query arrived over TCP.
    first: True for the initial reply; False for the follow-up reply the
           main loop sends after a spoofed "withtsig" answer.
    ns10:  True if the query arrived on the second listening address.

    Appends "<type> <qname>" to query.log and prints the same text to stdout
    without a trailing newline.  Sets the global ``dopass2`` when a spoofed
    "withtsig" answer is generated.  The query's COOKIE option object is
    modified in place before being echoed.  Returns the response message
    with the AA flag set.
    """
    global dopass2
    query = dns.message.from_wire(msg, keyring=keyring)
    question = query.question[0]
    qname = question.name.to_text()
    rrtype = question.rdtype
    label = qname.lower().split(".")[0]
    summary = "%s %s" % (dns.rdatatype.to_text(rrtype), qname)

    with open("query.log", "a") as logfile:
        logfile.write(summary + "\n")
        print(summary, end=" ")

    response = dns.message.make_response(query)
    response.set_rcode(NOERROR)

    if rrtype == A:
        # exempt potential nameserver A records.
        genuine = "10.53.0.10" if (label == "ns" and ns10) else "10.53.0.9"
        spoof = not tcp and (
            label in ("nocookie", "tcponly") or (first and label == "withtsig")
        )
        addrs = [genuine, "10.53.0.10"] if spoof else [genuine]
        for addr in addrs:
            response.answer.append(dns.rrset.from_text(qname, 1, IN, A, addr))
        if spoof and label == "withtsig":
            # Ask the main loop for a second, genuine reply.
            dopass2 = True
    else:
        if rrtype == NS:
            section, rdtype, rdata = response.answer, NS, "."
        elif rrtype == SOA:
            section, rdtype, rdata = response.answer, SOA, ". . 0 0 0 0 0"
        else:
            # Any other type: empty answer, SOA in the authority section.
            section, rdtype, rdata = response.authority, SOA, ". . 0 0 0 0 0"
        section.append(dns.rrset.from_text(qname, 1, IN, rdtype, rdata))

    # Add a server cookie to the response
    if label != "nocookie":
        for opt in query.options:
            if opt.otype != 10:  # Use 10 instead of COOKIE
                continue
            if not tcp and first and label == "withtsig":
                # The spoofed answer is signed with the "fake" key instead
                # of carrying a cookie.
                response.use_tsig(
                    keyring=keyring,
                    keyname=dns.name.from_text("fake"),
                    algorithm=HMAC_SHA256,
                )
            elif tcp or label != "tcponly":
                # Echo the client cookie; an 8-byte one is doubled so the
                # extra 8 bytes act as the server cookie.
                opt.data = opt.data * 2 if len(opt.data) == 8 else opt.data
                response.use_edns(options=[opt])

    response.flags |= dns.flags.AA
    return response
129
130
def sigterm(signum, frame):
    """SIGTERM handler: remove ans.pid, clear ``running`` and exit with status 0.

    signum, frame: standard signal-handler arguments (unused).

    ``global`` is required so the assignment clears the module-level flag
    the main loop tests, rather than creating a throw-away local.  A pid
    file that is already gone is ignored so shutdown still exits cleanly.
    """
    global running
    print("Shutting down now...")
    try:
        os.remove("ans.pid")
    except FileNotFoundError:
        pass
    running = False
    sys.exit(0)
136
137
############################################################################
# Main
#
# Bind the UDP and TCP query sockets on both addresses (IPv6 when
# available), write the pid file, and run the main loop answering every
# query that arrives.
############################################################################
ip4_addr1 = "10.53.0.9"
ip4_addr2 = "10.53.0.10"
ip6_addr1 = "fd92:7065:b8e:ffff::9"
ip6_addr2 = "fd92:7065:b8e:ffff::10"

# Listening port comes from $PORT; fall back to 5300 when it is unset or
# not an integer.
try:
    port = int(os.environ["PORT"])
except (KeyError, ValueError):
    port = 5300
154
# IPv4 listeners: a UDP socket and a listening TCP socket on each address.
# The TCP sockets get a 1-second timeout, presumably so accept() cannot
# block indefinitely if a connection disappears after select().
query4_udp1 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
query4_udp1.bind((ip4_addr1, port))
query4_tcp1 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
query4_tcp1.bind((ip4_addr1, port))
query4_tcp1.listen(1)
query4_tcp1.settimeout(1)

query4_udp2 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
query4_udp2.bind((ip4_addr2, port))
query4_tcp2 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
query4_tcp2.bind((ip4_addr2, port))
query4_tcp2.listen(1)
query4_tcp2.settimeout(1)

# IPv6 listeners are optional: if any of them cannot be created or bound,
# close whatever was opened and carry on with IPv4 only.
havev6 = True
query6_udp1 = None
query6_udp2 = None
query6_tcp1 = None
query6_tcp2 = None
try:
    query6_udp1 = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
    query6_udp1.bind((ip6_addr1, port))
    query6_tcp1 = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
    query6_tcp1.bind((ip6_addr1, port))
    query6_tcp1.listen(1)
    query6_tcp1.settimeout(1)

    query6_udp2 = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
    query6_udp2.bind((ip6_addr2, port))
    query6_tcp2 = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
    query6_tcp2.bind((ip6_addr2, port))
    query6_tcp2.listen(1)
    query6_tcp2.settimeout(1)
except (OSError, AttributeError):
    # OSError: socket creation/bind failed; AttributeError: in case
    # socket.AF_INET6 is unavailable on this platform.
    for sock6 in (query6_udp1, query6_tcp1, query6_udp2, query6_tcp2):
        if sock6 is not None:
            sock6.close()
    havev6 = False
198
# Install the SIGTERM handler, then record this process's pid in ans.pid
# (the handler deletes that file on shutdown).
signal.signal(signal.SIGTERM, sigterm)

pid = os.getpid()
with open("ans.pid", "w") as f:
    print(pid, file=f)
205
running = True

# Startup banner: one "Listening on" line per bound address.
print("Using DNS version %s" % dns.version.version)
print(
    "\n".join(
        "Listening on %s port %d" % (addr, port)
        for addr in (ip4_addr1, ip4_addr2) + ((ip6_addr1, ip6_addr2) if havev6 else ())
    )
)
print("Ctrl-c to quit")

# Sockets handed to select(): each IPv4 socket, followed by its IPv6
# counterpart when IPv6 is available.  (The name shadows the builtin
# input(); it is kept because the main loop refers to it.)
input = [
    sock
    for pair in (
        (query4_udp1, query6_udp1),
        (query4_tcp1, query6_tcp1),
        (query4_udp2, query6_udp2),
        (query4_tcp2, query6_tcp2),
    )
    for sock in (pair if havev6 else pair[:1])
]
229
def _recv_exact(conn, count):
    """Return up to *count* bytes read from stream socket *conn*.

    Loops because recv() may return fewer bytes than requested.  The result
    is shorter than *count* only if the peer closed the connection first.
    """
    chunks = []
    remaining = count
    while remaining > 0:
        chunk = conn.recv(remaining)
        if not chunk:
            break
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


# Main loop: wait until a socket is readable, then answer the query on it.
while running:
    try:
        inputready, outputready, exceptready = select.select(input, [], [])
    except (select.error, socket.error, KeyboardInterrupt):
        break

    for s in inputready:
        # ns10 is set for queries arriving on the second address.
        ns10 = False
        if s in (query4_udp1, query6_udp1, query4_udp2, query6_udp2):
            if s in (query4_udp1, query6_udp1):
                print(
                    "UDP Query received on %s"
                    % (ip4_addr1 if s == query4_udp1 else ip6_addr1),
                    end=" ",
                )
            if s in (query4_udp2, query6_udp2):
                print(
                    "UDP Query received on %s"
                    % (ip4_addr2 if s == query4_udp2 else ip6_addr2),
                    end=" ",
                )
                ns10 = True
            # Handle incoming queries
            msg = s.recvfrom(65535)
            dopass2 = False
            rsp = create_response(msg[0], False, True, ns10)
            print(dns.rcode.to_text(rsp.rcode()))
            s.sendto(rsp.to_wire(), msg[1])
            if dopass2:
                # create_response() produced a spoofed "withtsig" answer;
                # follow it with a second reply built with first=False.
                print("Sending second UDP response without TSIG", end=" ")
                rsp = create_response(msg[0], False, False, ns10)
                s.sendto(rsp.to_wire(), msg[1])
                print(dns.rcode.to_text(rsp.rcode()))

        if s in (query4_tcp1, query6_tcp1, query4_tcp2, query6_tcp2):
            try:
                (cs, _) = s.accept()
            except socket.timeout:
                continue
            try:
                if s in (query4_tcp1, query6_tcp1):
                    print(
                        "TCP Query received on %s"
                        % (ip4_addr1 if s == query4_tcp1 else ip6_addr1),
                        end=" ",
                    )
                if s in (query4_tcp2, query6_tcp2):
                    print(
                        "TCP Query received on %s"
                        % (ip4_addr2 if s == query4_tcp2 else ip6_addr2),
                        end=" ",
                    )
                    ns10 = True
                # DNS over TCP: each message is preceded by a 2-byte
                # big-endian length.
                buf = _recv_exact(cs, 2)
                if len(buf) < 2:
                    print("connection closed before query length")
                    continue
                length = struct.unpack(">H", buf)[0]
                # read the DNS message itself
                msg = _recv_exact(cs, length)
                if len(msg) < length:
                    print("connection closed before full query")
                    continue
                rsp = create_response(msg, True, True, ns10)
                print(dns.rcode.to_text(rsp.rcode()))
                wire = rsp.to_wire()
                # sendall: send() may transmit only part of the buffer.
                cs.sendall(struct.pack(">H", len(wire)) + wire)
            except socket.timeout:
                pass
            finally:
                cs.close()
    if not running:
        break
299