xref: /dpdk/usertools/dpdk-rss-flows.py (revision c1d145834f287aa8cf53de914618a7312f2c360e)
1#!/usr/bin/env python3
2# SPDX-License-Identifier: BSD-3-Clause
3# Copyright (c) 2014 6WIND S.A.
4# Copyright (c) 2023 Robin Jarry
5
6"""
7Craft IP{v6}/{TCP/UDP} traffic flows that will evenly spread over a given
8number of RX queues according to the RSS algorithm.
9"""
10
11import argparse
12import binascii
13import ctypes
14import ipaddress
15import json
16import struct
17import typing
18
19
# Any single IPv4 or IPv6 host address.
Address = typing.Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
# Any IPv4 or IPv6 network.
Network = typing.Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
# Any iterable of layer 4 port numbers.
PortList = typing.Iterable[int]
23
24
class Packet:
    """
    Addressing information of one flow: IP endpoints and layer 4 ports.
    """

    def __init__(self, ip_src: Address, ip_dst: Address, l4_sport: int, l4_dport: int):
        self.ip_src, self.ip_dst = ip_src, ip_dst
        self.l4_sport, self.l4_dport = l4_sport, l4_dport

    def reverse(self):
        """
        Return a new Packet for the opposite direction of the same flow
        (addresses swapped, ports swapped).
        """
        return Packet(self.ip_dst, self.ip_src, self.l4_dport, self.l4_sport)

    def hash_data(self, use_l4_port: bool = False) -> bytes:
        """
        Return the bytes fed to the RSS hash: the packed source address
        followed by the packed destination address and, when use_l4_port is
        set, the source and destination ports as big-endian 16 bit integers.
        """
        addresses = self.ip_src.packed + self.ip_dst.packed
        if not use_l4_port:
            return addresses
        return addresses + struct.pack(">HH", self.l4_sport, self.l4_dport)
46
47
class TrafficTemplate:
    """
    Pool of candidate flows described by two IP networks and two port lists.
    """

    def __init__(
        self,
        ip_src: Network,
        ip_dst: Network,
        l4_sport_range: PortList,
        l4_dport_range: PortList,
    ):
        self.ip_src, self.ip_dst = ip_src, ip_dst
        self.l4_sport_range, self.l4_dport_range = l4_sport_range, l4_dport_range

    def __iter__(self) -> typing.Iterator[Packet]:
        """
        Yield one Packet per (source host, destination host, source port,
        destination port) combination, ports varying fastest. Combinations
        where both addresses are equal are skipped.
        """
        for src in self.ip_src.hosts():
            for dst in self.ip_dst.hosts():
                if src != dst:
                    yield from (
                        Packet(src, dst, sport, dport)
                        for sport in self.l4_sport_range
                        for dport in self.l4_dport_range
                    )
69
70
class RSSAlgo:
    """
    Software model of Receive Side Scaling: a Toeplitz hash of the packet
    addressing data followed by a lookup in a redirection table (RETA).

    The redirection table has reta_size entries filled round-robin with the
    queue indexes 0 .. queues_count - 1.
    """

    def __init__(
        self,
        queues_count: int,
        key: bytes,
        reta_size: int,
        use_l4_port: bool,
    ):
        self.queues_count = queues_count
        self.reta = tuple(i % queues_count for i in range(reta_size))
        self.key = key
        self.use_l4_port = use_l4_port

    def toeplitz_hash(self, data: bytes) -> int:
        """
        Return the 32 bit Toeplitz hash of data computed with self.key.

        Raises ValueError when the key is too short for the amount of data:
        a 32 bit window of the key is needed for every input bit, hence the
        key must be at least 4 bytes longer than data.
        """
        # see rte_softrss_* in lib/hash/rte_thash.h
        key = self.key
        if len(key) < len(data) + 4:
            raise ValueError(
                f"RSS key too short: {len(key)} bytes, "
                f"at least {len(data) + 4} required to hash {len(data)} bytes"
            )

        hash_value = 0

        for i, byte in enumerate(data):
            if byte == 0:
                # no bit set, nothing to XOR for this byte
                continue
            # 40 bit slice of the key covering every 32 bit window that
            # starts on one of the 8 bits of this byte
            window = int.from_bytes(key[i : i + 5], "big")
            for j in range(8):
                if byte & (0x80 >> j):
                    hash_value ^= (window >> (8 - j)) & 0xFFFFFFFF

        return hash_value

    def get_queue_index(self, packet: "Packet") -> int:
        """
        Return the RX queue index that the redirection table assigns to packet.
        """
        bytes_to_hash = packet.hash_data(self.use_l4_port)

        # get the 32bit hash of the packet
        hash_value = self.toeplitz_hash(bytes_to_hash)

        # determine the offset in the redirection table; the mask keeps the
        # low bits of the hash, which assumes a power of two table size
        offset = hash_value & (len(self.reta) - 1)

        return self.reta[offset]
117
118
def balanced_traffic(
    algo: RSSAlgo,
    traffic_template: TrafficTemplate,
    check_reverse_traffic: bool = False,
    all_flows: bool = False,
) -> typing.Iterator[typing.Tuple[int, int, Packet]]:
    """
    Yield (queue, reverse_queue, packet) tuples picked from traffic_template.

    Unless all_flows is set, a packet is only kept when its queue has not been
    produced yet (and, with check_reverse_traffic, when the queue of the
    reversed packet has not been produced yet either), and the iteration ends
    once queues_count distinct queues have been produced.
    """
    one_per_queue = not all_flows
    seen = set()
    seen_reverse = set()

    for pkt in traffic_template:
        q = algo.get_queue_index(pkt)
        if one_per_queue and q in seen:
            # forward queue already covered
            continue

        qr = algo.get_queue_index(pkt.reverse())
        if check_reverse_traffic:
            if one_per_queue and qr in seen_reverse:
                # reverse queue already covered
                continue
            seen_reverse.add(qr)

        seen.add(q)
        yield q, qr, pkt

        if one_per_queue and len(seen) == algo.queues_count:
            # every queue has one flow, nothing more to look for
            return
153
154
# Port list meaning "no layer 4 port requested": default value of the port
# range options and fallback of port_range() for empty ranges. main() only
# hashes the ports when one of the port lists differs from this value.
NO_PORT = (0,)
156
157
class DriverInfo:
    """
    RSS parameters of a driver: a fixed hash key and a fixed RETA size.
    """

    def __init__(
        self,
        key: bytes = None,
        reta_size: int = None,
    ):
        self.__reta_size = reta_size
        self.__key = key

    def rss_key(self) -> bytes:
        """
        Return the RSS hash key given at construction.
        """
        return self.__key

    def reta_size(self, num_queues: int) -> int:
        """
        Return the redirection table size given at construction; num_queues
        is not used by this base implementation.
        """
        return self.__reta_size
168
169
class MlxDriverInfo(DriverInfo):
    """
    RSS parameters for "mlx": fixed key, RETA size depending on the number
    of RX queues.
    """

    def rss_key(self) -> bytes:
        """
        Return the 40 bytes default hash key.
        """
        # rss_hash_default_key, see drivers/net/mlx5/mlx5_rxq.c
        return bytes.fromhex(
            "2cc681d15bdbf4f7"
            "fca28319db1a3e94"
            "6b9e38d92c9c03d1"
            "ad9944a7d9563d59"
            "063c25f3fc1fdc2a"
        )

    def reta_size(self, num_queues: int) -> int:
        """
        Return num_queues when it is a power of two, otherwise the maximum
        table size (512).
        """
        is_power_of_two = num_queues & (num_queues - 1) == 0
        return num_queues if is_power_of_two else 512
192
193
# Well-known driver names accepted by --rss-key, with their default RSS key
# and redirection table size.
DEFAULT_DRIVERS = {
    "cnxk": DriverInfo(
        # roc_nix_rss_key_default_fill, see drivers/common/cnxk/roc_nix_rss.c
        # Marvell cnxk NICs take 48 bytes keys
        key=bytes.fromhex("feed0bad" * 12),
        reta_size=64,
    ),
    "intel": DriverInfo(
        # rss_intel_key, see drivers/net/intel/ixgbe/ixgbe_rxtx.c
        key=bytes.fromhex(
            "6d5a56da255b0ec2"
            "4167253d43a38fb0"
            "d0ca2bcbae7b30b4"
            "77cb2da38030f20c"
            "6a42b73bbeac01fa"
        ),
        reta_size=128,
    ),
    "i40e": DriverInfo(
        # rss_key_default, see drivers/net/intel/i40e/i40e_ethdev.c
        # i40e is the only driver that takes 52 bytes keys
        key=bytes.fromhex(
            "4439796bb54c5023"
            "b675ea5b124f9f30"
            "b8a2c03ddfdc4d02"
            "a08c9b334af64a4c"
            "05c6fa343958d855"
            "7d99583ae138c92e"
            "81150366"
        ),
        reta_size=512,
    ),
    "mlx": MlxDriverInfo(),
}
247
248
def port_range(value):
    """
    argparse type for layer 4 port options.

    Accepts either a single port ("80") or a range "<start>-<stop>" which
    expands to start, start + 1, ..., stop - 1 (stop itself is excluded).
    Returns a tuple of ports, or NO_PORT when the range is empty.

    Raises argparse.ArgumentTypeError when the value cannot be parsed or
    when a port does not fit in 16 bits.
    """
    try:
        if "-" in value:
            start, stop = value.split("-")
            ports = range(int(start), int(stop))
        else:
            port = int(value)
            ports = range(port, port + 1)
    except ValueError as e:
        raise argparse.ArgumentTypeError(str(e)) from e

    # Ports are later packed as unsigned 16 bit integers; reject values that
    # cannot be represented before materializing a potentially huge tuple.
    if ports and (ports[0] < 0 or ports[-1] > 65535):
        raise argparse.ArgumentTypeError("port numbers must be between 0 and 65535")

    return tuple(ports) or NO_PORT
259
260
def positive_int(value):
    """
    argparse type: parse value as an integer greater than zero.

    Raises argparse.ArgumentTypeError when value is not an integer or is
    lower than or equal to zero.
    """
    try:
        number = int(value)
    except ValueError as e:
        raise argparse.ArgumentTypeError(str(e)) from e
    if number <= 0:
        raise argparse.ArgumentTypeError("must be strictly positive")
    return number
269
270
def power_of_two(value):
    """
    argparse type: parse value as a strictly positive integer that is a
    power of two.

    Raises argparse.ArgumentTypeError otherwise.
    """
    number = positive_int(value)
    # a power of two has a single bit set, clearing it must leave zero
    if number & (number - 1):
        raise argparse.ArgumentTypeError("must be a power of two")
    return number
276
277
def parse_args():
    """
    Parse and validate the command line.

    Returns the argparse namespace with two fields post-processed:
    rss_key holds the key bytes (resolved from a well-known driver name or
    decoded from hex) and reta_size is always set (from the command line or
    from the driver defaults).

    Exits through parser.error() when the arguments are inconsistent:
    mixed IP versions, --json with --info, invalid hex key, RETA smaller than
    the number of queues, or a key too short for the data to hash.
    """
    parser = argparse.ArgumentParser(description=__doc__)

    parser.add_argument(
        "rx_queues",
        metavar="RX_QUEUES",
        type=positive_int,
        help="""
        The number of RX queues to fill.
        """,
    )
    parser.add_argument(
        "ip_src",
        metavar="SRC",
        type=ipaddress.ip_network,
        help="""
        The source IP network/address.
        """,
    )
    parser.add_argument(
        "ip_dst",
        metavar="DST",
        type=ipaddress.ip_network,
        help="""
        The destination IP network/address.
        """,
    )
    parser.add_argument(
        "-s",
        "--sport-range",
        type=port_range,
        default=NO_PORT,
        help="""
        The layer 4 (TCP/UDP) source port range.
        Can be a single fixed value or a range <start>-<end>.
        """,
    )
    parser.add_argument(
        "-d",
        "--dport-range",
        type=port_range,
        default=NO_PORT,
        help="""
        The layer 4 (TCP/UDP) destination port range.
        Can be a single fixed value or a range <start>-<end>.
        """,
    )
    parser.add_argument(
        "-r",
        "--check-reverse-traffic",
        action="store_true",
        help="""
        The reversed traffic (source <-> dest) should also be evenly balanced
        in the queues.
        """,
    )
    parser.add_argument(
        "-k",
        "--rss-key",
        default="intel",
        help=f"""
        The random key used to compute the RSS hash. This option
        supports either a well-known name or the hex value of the key
        (well-known names: {', '.join(DEFAULT_DRIVERS)}, default: intel).
        """,
    )
    parser.add_argument(
        "-t",
        "--reta-size",
        type=power_of_two,
        help="""
        Size of the redirection table or "RETA" (default: depends on driver if
        using a well-known driver name, otherwise 128).
        """,
    )
    parser.add_argument(
        "-a",
        "--all-flows",
        action="store_true",
        help="""
        Output ALL flows that can be created based on source and destination
        address/port ranges along their matched queue number. ATTENTION: this
        option can produce very long outputs depending on the address and port
        range sizes.
        """,
    )
    parser.add_argument(
        "-j",
        "--json",
        action="store_true",
        help="""
        Output in parseable JSON format.
        """,
    )
    parser.add_argument(
        "-i",
        "--info",
        action="store_true",
        help="""
        Print RETA size and RSS key above the results. Not available with --json.
        """,
    )

    args = parser.parse_args()

    if args.ip_src.version != args.ip_dst.version:
        parser.error(
            f"{args.ip_src} and {args.ip_dst} don't have the same protocol version"
        )

    if args.json and args.info:
        parser.error("--json and --info are mutually exclusive")

    if args.rss_key in DEFAULT_DRIVERS:
        driver_info = DEFAULT_DRIVERS[args.rss_key]
    else:
        # not a well-known name, the value must be the key itself in hex
        try:
            key = binascii.unhexlify(args.rss_key)
        except (TypeError, ValueError) as e:
            parser.error(f"RSS_KEY: {e}")
        driver_info = DriverInfo(key=key, reta_size=128)

    if args.reta_size is None:
        args.reta_size = driver_info.reta_size(args.rx_queues)

    if args.reta_size < args.rx_queues:
        parser.error("RETA_SIZE must be greater than or equal to RX_QUEUES")

    args.rss_key = driver_info.rss_key()

    # The Toeplitz hash consumes a 32 bit window of the key for every bit of
    # hashed data, so the key must be 4 bytes longer than that data (two
    # addresses, plus two 16 bit ports when port ranges are given). Check it
    # here to report a usage error instead of failing during the hash.
    data_len = 2 * (args.ip_src.max_prefixlen // 8)
    if args.sport_range != NO_PORT or args.dport_range != NO_PORT:
        data_len += 4
    if len(args.rss_key) < data_len + 4:
        parser.error(
            f"RSS_KEY: must be at least {data_len + 4} bytes long "
            f"to hash {data_len} bytes of packet data"
        )

    return args
409
410
def main():
    """
    Entry point: compute the flows matching the command line and print them,
    either as JSON or as a table with left-aligned columns.
    """
    args = parse_args()
    # ports take part in the hash as soon as one port option was given
    use_l4_port = not (args.sport_range == NO_PORT and args.dport_range == NO_PORT)

    algo = RSSAlgo(
        queues_count=args.rx_queues,
        key=args.rss_key,
        reta_size=args.reta_size,
        use_l4_port=use_l4_port,
    )
    template = TrafficTemplate(
        args.ip_src,
        args.ip_dst,
        args.sport_range,
        args.dport_range,
    )
    results = balanced_traffic(
        algo, template, args.check_reverse_traffic, args.all_flows
    )

    if args.json:
        flows = [
            {
                "queue": q,
                "queue_reverse": qr,
                "src_ip": str(pkt.ip_src),
                "dst_ip": str(pkt.ip_dst),
                "src_port": pkt.l4_sport,
                "dst_port": pkt.l4_dport,
            }
            for q, qr, pkt in results
        ]
        print(json.dumps(flows, indent=2))
        return

    header = (
        ["SRC_IP", "SPORT", "DST_IP", "DPORT", "QUEUE"]
        if use_l4_port
        else ["SRC_IP", "DST_IP", "QUEUE"]
    )
    if args.check_reverse_traffic:
        header.append("QUEUE_REVERSE")

    # first row is the header, then one row of strings per flow
    rows = [tuple(header)]
    for q, qr, pkt in results:
        if use_l4_port:
            values = [pkt.ip_src, pkt.l4_sport, pkt.ip_dst, pkt.l4_dport, q]
        else:
            values = [pkt.ip_src, pkt.ip_dst, q]
        if args.check_reverse_traffic:
            values.append(qr)
        rows.append(tuple(str(v) for v in values))

    # each column is as wide as its longest cell, header included
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]

    if args.info:
        print(f"RSS key:     {binascii.hexlify(args.rss_key).decode()}")
        print(f"RETA size:   {args.reta_size}")
        print()

    for row in rows:
        # pad every cell but the last one to avoid trailing whitespace
        cells = [cell.ljust(width) for cell, width in zip(row[:-1], widths)]
        cells.append(row[-1])
        print("    ".join(cells))
483
484
# Run the command line tool only when this file is executed as a script.
if __name__ == "__main__":
    main()
487