xref: /dpdk/usertools/dpdk-rss-flows.py (revision 7a86a806dcf32213171adc9dc36d87b3d0c2750b)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: BSD-3-Clause
3# Copyright (c) 2014 6WIND S.A.
4# Copyright (c) 2023 Robin Jarry
5
6"""
7Craft IP{v6}/{TCP/UDP} traffic flows that will evenly spread over a given
8number of RX queues according to the RSS algorithm.
9"""
10
11import argparse
12import binascii
13import ctypes
14import ipaddress
15import json
16import struct
17import typing
18
19
# Any single IPv4 or IPv6 host address.
Address = typing.Union[
    ipaddress.IPv4Address,
    ipaddress.IPv6Address,
]
# Any IPv4 or IPv6 network; ipaddress.ip_network() on a bare address
# yields a /32 (IPv4) or /128 (IPv6) network.
Network = typing.Union[
    ipaddress.IPv4Network,
    ipaddress.IPv6Network,
]
# A collection of TCP/UDP port numbers.
PortList = typing.Iterable[int]
23
24
class Packet:
    """One flow: a source/destination IP pair plus TCP/UDP ports."""

    def __init__(self, ip_src: Address, ip_dst: Address, l4_sport: int, l4_dport: int):
        self.ip_src, self.ip_dst = ip_src, ip_dst
        self.l4_sport, self.l4_dport = l4_sport, l4_dport

    def reverse(self):
        """Return the same flow seen in the opposite direction.

        Source and destination addresses are swapped, and so are the ports.
        """
        return Packet(self.ip_dst, self.ip_src, self.l4_dport, self.l4_sport)

    def hash_data(self, use_l4_port: bool = False) -> bytes:
        """Return the bytes fed to the RSS hash for this flow.

        The layout is source IP, destination IP (network byte order) and,
        when ``use_l4_port`` is true, source port then destination port, each
        as a big-endian 16-bit integer.
        """
        chunks = [self.ip_src.packed, self.ip_dst.packed]
        if use_l4_port:
            chunks.append(struct.pack(">HH", self.l4_sport, self.l4_dport))
        return b"".join(chunks)
46
47
class TrafficTemplate:
    """Enumerate every flow described by address networks and port ranges.

    Iterating a template yields one ``Packet`` per combination of a source
    host, a destination host, a source port and a destination port, skipping
    combinations whose source and destination addresses are equal. Hosts come
    from ``Network.hosts()``, which for most prefix lengths excludes the
    network and broadcast addresses.

    :param ip_src: network providing the source addresses.
    :param ip_dst: network providing the destination addresses.
    :param l4_sport_range: source ports to combine with each address pair.
    :param l4_dport_range: destination ports to combine with each address pair.
    """

    def __init__(
        self,
        ip_src: Network,
        ip_dst: Network,
        l4_sport_range: PortList,
        l4_dport_range: PortList,
    ):
        self.ip_src = ip_src
        self.ip_dst = ip_dst
        # The port collections are walked once per (src, dst) pair and once
        # per __iter__ call. A one-shot iterator (e.g. a generator) would be
        # exhausted after the first pass, so store re-iterable tuples.
        self.l4_sport_range = tuple(l4_sport_range)
        self.l4_dport_range = tuple(l4_dport_range)

    def __iter__(self) -> typing.Iterator[Packet]:
        # Plain nested loops rather than itertools.product(): product()
        # consumes its inputs up front, and hosts() can be enormous (large
        # IPv6 prefixes), while callers usually stop after a few flows.
        for ip_src in self.ip_src.hosts():
            for ip_dst in self.ip_dst.hosts():
                if ip_src == ip_dst:
                    continue
                for sport in self.l4_sport_range:
                    for dport in self.l4_dport_range:
                        yield Packet(ip_src, ip_dst, sport, dport)
69
70
class RSSAlgo:
    """Software model of RSS queue selection: Toeplitz hash + RETA lookup.

    :param queues_count: number of RX queues the redirection table spreads
        traffic over; must be strictly positive.
    :param key: RSS hash key. It must be at least 4 bytes longer than the
        hashed data (IPv6 addresses plus L4 ports are 36 bytes, so 40 bytes
        covers every input this script builds).
    :param reta_size: number of redirection table entries; must be a power
        of two.
    :param use_l4_port: include the TCP/UDP ports in the hashed data.
    :raises ValueError: if ``queues_count`` is not strictly positive or
        ``reta_size`` is not a power of two.
    """

    def __init__(
        self,
        queues_count: int,
        key: bytes,
        reta_size: int,
        use_l4_port: bool,
    ):
        if queues_count <= 0:
            raise ValueError("queues_count must be strictly positive")
        # get_queue_index() masks the hash with (reta_size - 1). That only
        # reaches every table entry when the size is a power of two.
        if reta_size <= 0 or reta_size & (reta_size - 1):
            raise ValueError("reta_size must be a power of two")
        self.queues_count = queues_count
        # Round-robin table: entry i redirects to queue i % queues_count.
        self.reta = tuple(i % queues_count for i in range(reta_size))
        self.key = key
        self.use_l4_port = use_l4_port

    def toeplitz_hash(self, data: bytes) -> int:
        """Return the 32-bit Toeplitz hash of ``data`` under ``self.key``.

        For every set input bit at position ``pos`` (most significant bit of
        the first byte is position 0), the 32 key bits starting at bit
        ``pos`` are XORed into the result. This mirrors rte_softrss_* in
        lib/hash/rte_thash.h.

        :raises ValueError: if the key is shorter than ``len(data) + 4``
            bytes, the minimum for the last input bit's 32-bit key window.
        """
        if len(self.key) < len(data) + 4:
            raise ValueError(
                f"RSS key too short: {len(self.key)} bytes"
                f" for {len(data)} bytes of hashed data"
            )
        key_bits = len(self.key) * 8
        key_value = int.from_bytes(self.key, "big")
        hash_value = 0
        for pos in range(len(data) * 8):
            if data[pos >> 3] & (0x80 >> (pos & 7)):
                # 32-bit window of the key that begins at bit `pos`.
                hash_value ^= (key_value >> (key_bits - 32 - pos)) & 0xFFFFFFFF
        return hash_value

    def get_queue_index(self, packet: "Packet") -> int:
        """Return the RX queue the NIC would pick for ``packet``."""
        bytes_to_hash = packet.hash_data(self.use_l4_port)

        # get the 32bit hash of the packet
        hash_value = self.toeplitz_hash(bytes_to_hash)

        # The low bits of the hash index the redirection table. The mask
        # equals a modulo because __init__ enforces a power-of-two size.
        offset = hash_value & (len(self.reta) - 1)

        return self.reta[offset]
117
118
def balanced_traffic(
    algo: RSSAlgo,
    traffic_template: TrafficTemplate,
    check_reverse_traffic: bool = False,
    all_flows: bool = False,
) -> typing.Iterator[typing.Tuple[int, int, Packet]]:
    """Pick flows from ``traffic_template`` that spread over the RX queues.

    Yields ``(queue, reverse_queue, packet)`` tuples, where ``reverse_queue``
    is the queue of ``packet.reverse()``.

    Unless ``all_flows`` is set, a packet is skipped when its queue already
    received a flow. With ``check_reverse_traffic``, it is also skipped when
    its reverse queue was already used. Iteration stops once every one of
    ``algo.queues_count`` queues got a flow. With ``all_flows``, every packet
    of the template is yielded.
    """
    dedupe = not all_flows
    used = set()
    used_reverse = set()

    for packet in traffic_template:
        queue = algo.get_queue_index(packet)
        if dedupe and queue in used:
            continue

        queue_reverse = algo.get_queue_index(packet.reverse())
        if check_reverse_traffic:
            if dedupe and queue_reverse in used_reverse:
                continue
            used_reverse.add(queue_reverse)

        used.add(queue)
        yield queue, queue_reverse, packet

        # Every queue has its flow: nothing more to look for.
        if dedupe and len(used) == algo.queues_count:
            break
153
154
# Port list meaning "no TCP/UDP port". It is the default for --sport-range and
# --dport-range, and port_range() returns it for an empty range. When both
# ranges equal it, main() leaves the ports out of the hashed data.
NO_PORT = (0,)

# Default RSS keys copied from DPDK drivers. They can be selected by name with
# --rss-key through DEFAULT_DRIVER_KEYS below.
# fmt: off
# rss_intel_key, see drivers/net/ixgbe/ixgbe_rxtx.c
RSS_KEY_INTEL = bytes(
    (
        0x6d, 0x5a, 0x56, 0xda, 0x25, 0x5b, 0x0e, 0xc2,
        0x41, 0x67, 0x25, 0x3d, 0x43, 0xa3, 0x8f, 0xb0,
        0xd0, 0xca, 0x2b, 0xcb, 0xae, 0x7b, 0x30, 0xb4,
        0x77, 0xcb, 0x2d, 0xa3, 0x80, 0x30, 0xf2, 0x0c,
        0x6a, 0x42, 0xb7, 0x3b, 0xbe, 0xac, 0x01, 0xfa,
    )
)
# rss_hash_default_key, see drivers/net/mlx5/mlx5_rxq.c
RSS_KEY_MLX = bytes(
    (
        0x2c, 0xc6, 0x81, 0xd1, 0x5b, 0xdb, 0xf4, 0xf7,
        0xfc, 0xa2, 0x83, 0x19, 0xdb, 0x1a, 0x3e, 0x94,
        0x6b, 0x9e, 0x38, 0xd9, 0x2c, 0x9c, 0x03, 0xd1,
        0xad, 0x99, 0x44, 0xa7, 0xd9, 0x56, 0x3d, 0x59,
        0x06, 0x3c, 0x25, 0xf3, 0xfc, 0x1f, 0xdc, 0x2a,
    )
)
# rss_key_default, see drivers/net/i40e/i40e_ethdev.c
# i40e is the only driver that takes 52 bytes keys
RSS_KEY_I40E = bytes(
    (
        0x44, 0x39, 0x79, 0x6b, 0xb5, 0x4c, 0x50, 0x23,
        0xb6, 0x75, 0xea, 0x5b, 0x12, 0x4f, 0x9f, 0x30,
        0xb8, 0xa2, 0xc0, 0x3d, 0xdf, 0xdc, 0x4d, 0x02,
        0xa0, 0x8c, 0x9b, 0x33, 0x4a, 0xf6, 0x4a, 0x4c,
        0x05, 0xc6, 0xfa, 0x34, 0x39, 0x58, 0xd8, 0x55,
        0x7d, 0x99, 0x58, 0x3a, 0xe1, 0x38, 0xc9, 0x2e,
        0x81, 0x15, 0x03, 0x66,
    )
)
# fmt: on
# Key names accepted by rss_key() (the --rss-key option) in place of hex.
DEFAULT_DRIVER_KEYS = {
    "intel": RSS_KEY_INTEL,
    "mlx": RSS_KEY_MLX,
    "i40e": RSS_KEY_I40E,
}
197
198
def rss_key(value):
    """Argparse type converter for --rss-key.

    :param value: a well-known key name (a key of DEFAULT_DRIVER_KEYS) or the
        key as a hexadecimal string. Bytes may be separated by ":", as in the
        hash key printed by ``ethtool -x``.
    :returns: the key as bytes.
    :raises argparse.ArgumentTypeError: if the value is not valid hex or the
        decoded key is not 40 or 52 bytes long.
    """
    if value in DEFAULT_DRIVER_KEYS:
        return DEFAULT_DRIVER_KEYS[value]
    if isinstance(value, str):
        # Accept "6d:5a:56:..." as well as "6d5a56...".
        value = value.replace(":", "")
    try:
        key = binascii.unhexlify(value)
    except (TypeError, ValueError) as e:
        raise argparse.ArgumentTypeError(str(e)) from e
    if len(key) not in (40, 52):
        raise argparse.ArgumentTypeError("The key must be 40 or 52 bytes long")
    return key
208        raise argparse.ArgumentTypeError(str(e)) from e
209
210
def port_range(value):
    """Argparse type converter for --sport-range / --dport-range.

    :param value: a single port ("80") or a range "<start>-<end>", where
        <end> is excluded ("1000-1004" gives ports 1000 to 1003).
    :returns: a tuple of port numbers, or NO_PORT if the range is empty.
    :raises argparse.ArgumentTypeError: on malformed input or if a port
        falls outside 0-65535. Out-of-range ports would otherwise only fail
        later, with struct.error, when they are packed into the hashed data.
    """
    try:
        if "-" in value:
            start, stop = value.split("-")
            start, stop = int(start), int(stop)
        else:
            start = int(value)
            stop = start + 1
    except ValueError as e:
        raise argparse.ArgumentTypeError(str(e)) from e
    # Check the bounds before building the tuple, so that a huge range is
    # rejected without being materialized. stop is exclusive, hence 65536.
    if start < stop and (start < 0 or stop > 0x10000):
        raise argparse.ArgumentTypeError(
            f"{value}: ports must be in the 0-65535 range"
        )
    return tuple(range(start, stop)) or NO_PORT
220        raise argparse.ArgumentTypeError(str(e)) from e
221
222
def positive_int(value):
    """Argparse type converter: an integer strictly greater than zero.

    :raises argparse.ArgumentTypeError: if ``value`` is not an integer or is
        zero or negative.
    """
    try:
        number = int(value)
    except ValueError as e:
        raise argparse.ArgumentTypeError(str(e)) from e
    if number > 0:
        return number
    raise argparse.ArgumentTypeError("must be strictly positive")
230        raise argparse.ArgumentTypeError(str(e)) from e
231
232
def power_of_two(value):
    """Argparse type converter: a strictly positive power of two.

    :raises argparse.ArgumentTypeError: if ``value`` is not a strictly
        positive integer (see positive_int) or has more than one bit set.
    """
    number = positive_int(value)
    # A positive integer is a power of two iff exactly one bit is set.
    if bin(number).count("1") == 1:
        return number
    raise argparse.ArgumentTypeError("must be a power of two")
237    return i
238
239
def parse_args():
    """Parse and cross-validate the command line.

    Besides per-option type checks, this rejects source and destination
    networks of different IP versions, and a RETA smaller than the number of
    RX queues. Errors are reported through ``parser.error``.

    :returns: the parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    # Build the list of key names from the table itself so the help text
    # cannot drift from what rss_key() accepts.
    key_names = ", ".join(f'"{name}"' for name in DEFAULT_DRIVER_KEYS)

    parser.add_argument(
        "rx_queues",
        metavar="RX_QUEUES",
        type=positive_int,
        help="""
        The number of RX queues to fill.
        """,
    )
    parser.add_argument(
        "ip_src",
        metavar="SRC",
        type=ipaddress.ip_network,
        help="""
        The source IP network/address.
        """,
    )
    parser.add_argument(
        "ip_dst",
        metavar="DST",
        type=ipaddress.ip_network,
        help="""
        The destination IP network/address.
        """,
    )
    parser.add_argument(
        "-s",
        "--sport-range",
        type=port_range,
        default=NO_PORT,
        help="""
        The layer 4 (TCP/UDP) source port range.
        Can be a single fixed value or a range <start>-<end> (<end> excluded).
        """,
    )
    parser.add_argument(
        "-d",
        "--dport-range",
        type=port_range,
        default=NO_PORT,
        help="""
        The layer 4 (TCP/UDP) destination port range.
        Can be a single fixed value or a range <start>-<end> (<end> excluded).
        """,
    )
    parser.add_argument(
        "-r",
        "--check-reverse-traffic",
        action="store_true",
        help="""
        The reversed traffic (source <-> dest) should also be evenly balanced
        in the queues.
        """,
    )
    parser.add_argument(
        "-k",
        "--rss-key",
        default=RSS_KEY_INTEL,
        type=rss_key,
        help=f"""
        The random 40 or 52 bytes key used to compute the RSS hash. This
        option supports either a well-known name or the hex value of the key
        (well-known names: {key_names}, default: "intel").
        """,
    )
    parser.add_argument(
        "-t",
        "--reta-size",
        default=128,
        type=power_of_two,
        help="""
        Size of the redirection table or "RETA" (default: 128).
        """,
    )
    parser.add_argument(
        "-a",
        "--all-flows",
        action="store_true",
        help="""
        Output ALL flows that can be created based on source and destination
        address/port ranges along their matched queue number. ATTENTION: this
        option can produce very long outputs depending on the address and port
        range sizes.
        """,
    )
    parser.add_argument(
        "-j",
        "--json",
        action="store_true",
        help="""
        Output in parseable JSON format.
        """,
    )

    args = parser.parse_args()

    if args.ip_src.version != args.ip_dst.version:
        parser.error(
            f"{args.ip_src} and {args.ip_dst} don't have the same protocol version"
        )
    if args.reta_size < args.rx_queues:
        parser.error("RETA_SIZE must be greater than or equal to RX_QUEUES")

    return args
346
347
def _print_json(results) -> None:
    """Print (queue, reverse queue, packet) tuples as an indented JSON list."""
    flows = [
        {
            "queue": q,
            "queue_reverse": qr,
            "src_ip": str(pkt.ip_src),
            "dst_ip": str(pkt.ip_dst),
            "src_port": pkt.l4_sport,
            "dst_port": pkt.l4_dport,
        }
        for q, qr, pkt in results
    ]
    print(json.dumps(flows, indent=2))


def _print_table(results, use_l4_port: bool, show_reverse: bool) -> None:
    """Print flows as left-aligned columns separated by four spaces.

    Port columns appear only when ``use_l4_port`` is true, and the reverse
    queue column only when ``show_reverse`` is true.
    """
    if use_l4_port:
        header = ["SRC_IP", "SPORT", "DST_IP", "DPORT", "QUEUE"]
    else:
        header = ["SRC_IP", "DST_IP", "QUEUE"]
    if show_reverse:
        header.append("QUEUE_REVERSE")

    table = [tuple(header)]
    for q, qr, pkt in results:
        if use_l4_port:
            values = [pkt.ip_src, pkt.l4_sport, pkt.ip_dst, pkt.l4_dport, q]
        else:
            values = [pkt.ip_src, pkt.ip_dst, q]
        if show_reverse:
            values.append(qr)
        table.append(tuple(str(v) for v in values))

    # Each column is as wide as its longest cell, header included.
    widths = [max(len(line[col]) for line in table) for col in range(len(header))]
    # The last column is not padded, to avoid trailing whitespace.
    line_fmt = "    ".join([f"%-{w}s" for w in widths[:-1]] + ["%s"])
    for line in table:
        print(line_fmt % line)


def main():
    """Entry point: compute balanced flows and print them as JSON or a table."""
    args = parse_args()
    # Ports are hashed only if at least one port range was given.
    use_l4_port = args.sport_range != NO_PORT or args.dport_range != NO_PORT

    algo = RSSAlgo(
        queues_count=args.rx_queues,
        key=args.rss_key,
        reta_size=args.reta_size,
        use_l4_port=use_l4_port,
    )
    template = TrafficTemplate(
        args.ip_src,
        args.ip_dst,
        args.sport_range,
        args.dport_range,
    )
    results = balanced_traffic(
        algo, template, args.check_reverse_traffic, args.all_flows
    )

    if args.json:
        _print_json(results)
    else:
        _print_table(results, use_l4_port, args.check_reverse_traffic)


if __name__ == "__main__":
    main()
419