xref: /netbsd-src/external/mpl/bind/dist/bin/tests/system/pipelined/ans5/ans.py (revision f281902de12281841521aa31ef834ad944d725e2)
1# Copyright (C) Internet Systems Consortium, Inc. ("ISC")
2#
3# SPDX-License-Identifier: MPL-2.0
4#
5# This Source Code Form is subject to the terms of the Mozilla Public
6# License, v. 2.0.  If a copy of the MPL was not distributed with this
7# file, you can obtain one at https://mozilla.org/MPL/2.0/.
8#
9# See the COPYRIGHT file distributed with this work for additional
10# information regarding copyright ownership.
11
12############################################################################
13#
# This tool acts as a TCP/UDP proxy and delays all packets received from
# clients by 500 milliseconds; responses from the server are not delayed.
16#
# We use it to check pipelining - a client sends 8 questions over a
# pipelined connection that require asking a normal (examplea) and a
# slow-responding (exampleb) server:
20# a.examplea
21# a.exampleb
22# b.examplea
23# b.exampleb
24# c.examplea
25# c.exampleb
26# d.examplea
27# d.exampleb
28#
29# If pipelining works properly the answers will be returned out of order
30# with all answers from examplea returned first, and then all answers
31# from exampleb.
32#
33############################################################################
34
35from __future__ import print_function
36
37import datetime
38import os
39import select
40import signal
41import socket
42import sys
43import time
44import threading
45import struct
46
# Seconds each packet received from a client is held before it is forwarded.
DELAY = 0.5
# Delayer threads started by main(); sigterm() closes and joins each of them.
THREADS = list()
49
50
def log(msg):
    """Write *msg* to stdout, prefixed by the current local time.

    The prefix looks like ``05-Mar-2024 13:07:42.123456 `` (day, abbreviated
    month, year, time with microseconds, then a single space).
    """
    timestamp = datetime.datetime.now().strftime("%d-%b-%Y %H:%M:%S.%f ")
    line = timestamp + msg
    print(line)
53
54
def sigterm(*_):
    """Signal handler: stop every delayer thread, remove the pid file, exit.

    The signal number and frame arguments are ignored.  Each thread is asked
    to stop and then waited for, one at a time, before ``ans.pid`` is deleted
    and the process exits with status 0.
    """
    log("SIGTERM received, shutting down")
    for worker in THREADS:
        worker.close()  # ask the worker's run loop to finish ...
        worker.join()  # ... and wait until it actually has
    os.remove("ans.pid")
    sys.exit(0)
62
63
class TCPDelayer(threading.Thread):
    """Delaying proxy for a single accepted TCP connection.

    For the client connection *conn* a new upstream connection to (ip, port)
    is opened.  Every chunk read from the client is held in a queue for DELAY
    seconds before being written upstream; data coming back from upstream is
    forwarded to the client immediately.  Both sockets are closed when the
    thread finishes.

    In the pipelined test TCP should not be used, but it's here for
    completeness.
    """

    def __init__(self, conn, ip, port):
        threading.Thread.__init__(self)
        self.conn = conn
        self.cconn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.cconn.connect((ip, port))
        except OSError:
            # Don't leak the upstream socket when the server is unreachable.
            self.cconn.close()
            raise
        # (deadline, data) pairs, appended in arrival order.
        self.queue = []
        self.running = True

    def close(self):
        """Ask run() to stop; the flag is checked once per loop iteration."""
        self.running = False

    def run(self):
        """Shuttle data between client and upstream until EOF or close()."""
        try:
            while self.running:
                # Wait until the head of the queue is due, or 0.5 s when idle.
                # (An empty list must be tested explicitly: indexing it raises
                # IndexError, which the old StopIteration handler never caught.)
                curr_timeout = 0.5
                if self.queue:
                    curr_timeout = self.queue[0][0] - time.time()
                if curr_timeout > 0:
                    rfds, _, _ = select.select(
                        [self.conn, self.cconn], [], [], curr_timeout
                    )
                    if self.conn in rfds:
                        data = self.conn.recv(65535)
                        if not data:
                            return  # client closed the connection
                        self.queue.append((time.time() + DELAY, data))
                    if self.cconn in rfds:
                        data = self.cconn.recv(65535)
                        if not data:
                            return  # upstream closed the connection
                        # sendall: send() may write only part of the buffer.
                        self.conn.sendall(data)
                # Forward every client chunk whose delay has expired.
                while self.queue and self.queue[0][0] - time.time() < 0:
                    _, data = self.queue.pop(0)
                    self.cconn.sendall(data)
        finally:
            self.cconn.close()
            self.conn.close()
112
113
class UDPDelayer(threading.Thread):
    """Every incoming UDP packet is put in a queue for DELAY time, then
    it's sent to (ip, port). We remember the query id to send the
    response we get to a proper source, responses are not delayed.

    Datagrams too short to contain the 2-byte DNS message ID are logged and
    dropped instead of crashing (or stopping) the thread.  The client socket
    created here is closed when run() returns; *usock* stays owned by the
    caller.
    """

    def __init__(self, usock, ip, port):
        threading.Thread.__init__(self)
        self.sock = usock
        self.csock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.dst = (ip, port)
        # (deadline, query) pairs, appended in arrival order.
        self.queue = []
        # DNS message ID -> address of the client that sent that query.
        self.qid_mapping = {}
        self.running = True

    def close(self):
        """Ask run() to stop; the flag is checked once per loop iteration."""
        self.running = False

    def _recv_dns(self, sock):
        """Read one datagram from *sock* and return (data, addr, qid).

        Returns None (after logging) when the datagram is shorter than the
        2-byte message ID, so struct.unpack cannot fail on it.
        """
        data, addr = sock.recvfrom(65535)
        if len(data) < 2:
            log("Ignoring short datagram from %s" % str(addr))
            return None
        return data, addr, struct.unpack(">H", data[:2])[0]

    def run(self):
        """Proxy datagrams until close() is called."""
        try:
            while self.running:
                # Wait until the head of the queue is due, or 0.5 s when idle.
                # A due (or overdue) head skips select() and is sent below.
                curr_timeout = 0.5
                if self.queue:
                    curr_timeout = self.queue[0][0] - time.time()
                if curr_timeout > 0:
                    rfds, _, _ = select.select(
                        [self.sock, self.csock], [], [], curr_timeout
                    )
                    if self.sock in rfds:
                        query = self._recv_dns(self.sock)
                        if query is not None:
                            data, addr, qid = query
                            self.queue.append((time.time() + DELAY, data))
                            log(
                                "Received a query from %s, queryid %d"
                                % (str(addr), qid)
                            )
                            self.qid_mapping[qid] = addr
                    if self.csock in rfds:
                        response = self._recv_dns(self.csock)
                        if response is not None:
                            data, addr, qid = response
                            dst = self.qid_mapping.get(qid)
                            if dst is not None:
                                self.sock.sendto(data, dst)
                                log(
                                    "Received a response from %s, queryid %d, sending to %s"
                                    % (str(addr), qid, str(dst))
                                )
                # Send every query whose delay has expired.
                while self.queue and self.queue[0][0] - time.time() < 0:
                    _, data = self.queue.pop(0)
                    qid = struct.unpack(">H", data[:2])[0]
                    log("Sending a query to %s, queryid %d" % (str(self.dst), qid))
                    self.csock.sendto(data, self.dst)
        finally:
            self.csock.close()
168
169
def main():
    """Run the delaying proxy on 10.53.0.5, forwarding to 10.53.0.2.

    The port comes from $PORT (default 5300) and is used both for listening
    and for the upstream server.  The process id is written to ``ans.pid``;
    SIGTERM and SIGINT are handled by sigterm().  One UDPDelayer serves all
    UDP traffic, and each accepted TCP connection gets its own TCPDelayer.
    """
    signal.signal(signal.SIGTERM, sigterm)
    signal.signal(signal.SIGINT, sigterm)

    with open("ans.pid", "w") as pidfile:
        print(os.getpid(), file=pidfile)

    listenip = "10.53.0.5"
    serverip = "10.53.0.2"

    port = int(os.environ.get("PORT", 5300))

    log("Listening on %s:%d" % (listenip, port))

    usock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    usock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    usock.bind((listenip, port))
    thread = UDPDelayer(usock, serverip, port)
    thread.start()
    THREADS.append(thread)

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind((listenip, port))
    sock.listen(1)
    sock.settimeout(1)

    while True:
        try:
            (clientsock, _) = sock.accept()
        except socket.timeout:
            # accept() gives up after 1 s; just keep waiting.
            continue
        log("Accepted connection from %s" % clientsock)
        try:
            thread = TCPDelayer(clientsock, serverip, port)
        except OSError as err:
            # Upstream unreachable: drop this client but keep serving others
            # instead of letting the exception end the accept loop.
            log("Cannot connect to %s:%d: %s" % (serverip, port, err))
            clientsock.close()
            continue
        thread.start()
        THREADS.append(thread)


if __name__ == "__main__":
    main()
213