#!/usr/bin/env python3
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2014 6WIND S.A.
# Copyright (c) 2023 Robin Jarry

"""
Craft IP{v6}/{TCP/UDP} traffic flows that will evenly spread over a given
number of RX queues according to the RSS algorithm.
"""

import argparse
import binascii
import ctypes
import ipaddress
import json
import struct
import typing


Address = typing.Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
Network = typing.Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
PortList = typing.Iterable[int]

# Highest valid TCP/UDP port: ports are packed as unsigned 16-bit integers
# (struct format ">H") when building the hash input.
MAX_PORT = 0xFFFF


class Packet:
    """One flow: source/destination IP addresses and layer 4 ports."""

    def __init__(self, ip_src: Address, ip_dst: Address, l4_sport: int, l4_dport: int):
        self.ip_src = ip_src
        self.ip_dst = ip_dst
        self.l4_sport = l4_sport
        self.l4_dport = l4_dport

    def reverse(self) -> "Packet":
        """Return the flow of the opposite direction (source <-> destination)."""
        return Packet(
            ip_src=self.ip_dst,
            l4_sport=self.l4_dport,
            ip_dst=self.ip_src,
            l4_dport=self.l4_sport,
        )

    def hash_data(self, use_l4_port: bool = False) -> bytes:
        """Return the RSS hash input for this flow.

        The input is the packed source address followed by the packed
        destination address. When *use_l4_port* is true, the source and
        destination ports are appended as big-endian 16-bit integers.
        """
        data = self.ip_src.packed + self.ip_dst.packed
        if use_l4_port:
            data += struct.pack(">H", self.l4_sport)
            data += struct.pack(">H", self.l4_dport)
        return data


def _hosts(network: Network) -> typing.Iterator[Address]:
    """Yield the usable hosts of *network*, or its address if it has none.

    A single address given on the command line becomes a /32 (IPv4) or /128
    (IPv6) network. Older Python versions return no hosts at all for such
    networks, so fall back to the network address to still produce flows.
    """
    found = False
    for host in network.hosts():
        found = True
        yield host
    if not found:
        yield network.network_address


class TrafficTemplate:
    """All flows between two IP networks over two L4 port ranges."""

    def __init__(
        self,
        ip_src: Network,
        ip_dst: Network,
        l4_sport_range: PortList,
        l4_dport_range: PortList,
    ):
        self.ip_src = ip_src
        self.ip_dst = ip_dst
        # The port ranges are iterated once per (source, destination) pair:
        # materialize them so one-shot iterators are not exhausted after the
        # first pass.
        self.l4_sport_range = tuple(l4_sport_range)
        self.l4_dport_range = tuple(l4_dport_range)

    def __iter__(self) -> typing.Iterator[Packet]:
        """Yield every (src, dst, sport, dport) combination, skipping src == dst."""
        for ip_src in _hosts(self.ip_src):
            for ip_dst in _hosts(self.ip_dst):
                if ip_src == ip_dst:
                    continue
                for sport in self.l4_sport_range:
                    for dport in self.l4_dport_range:
                        yield Packet(ip_src, ip_dst, sport, dport)


class RSSAlgo:
    """Toeplitz RSS hash followed by a redirection table (RETA) lookup."""

    def __init__(
        self,
        queues_count: int,
        key: bytes,
        reta_size: int,
        use_l4_port: bool,
    ):
        self.queues_count = queues_count
        # RETA entries are assigned to the queues round-robin.
        self.reta = tuple(i % queues_count for i in range(reta_size))
        self.key = key
        self.use_l4_port = use_l4_port

    def toeplitz_hash(self, data: bytes) -> int:
        """Return the 32-bit Toeplitz hash of *data* computed with ``self.key``.

        For every bit set in *data* (most significant bit first), the 32-bit
        window of the key starting at that same bit offset is XORed into the
        result. The key must be at least ``len(data) + 4`` bytes long.
        """
        # see rte_softrss_* in lib/hash/rte_thash.h
        hash_value = ctypes.c_uint32(0)

        for i, byte in enumerate(data):
            for j in range(8):
                if (byte >> (7 - j)) & 0x01 == 0:
                    continue

                # c_uint32 wraps on overflow, emulating unsigned C arithmetic
                # for the left shift below.
                keyword = ctypes.c_uint32(0)
                keyword.value |= self.key[i] << 24
                keyword.value |= self.key[i + 1] << 16
                keyword.value |= self.key[i + 2] << 8
                keyword.value |= self.key[i + 3]

                if j > 0:
                    # slide the 32-bit window j bits further into the key
                    keyword.value <<= j
                    keyword.value |= self.key[i + 4] >> (8 - j)

                hash_value.value ^= keyword.value

        return hash_value.value

    def get_queue_index(self, packet: Packet) -> int:
        """Return the RX queue that *packet* is steered to.

        The low bits of the hash index the RETA. Masking with ``len - 1``
        assumes the RETA size is a power of two (enforced on the command
        line by ``power_of_two``).
        """
        bytes_to_hash = packet.hash_data(self.use_l4_port)

        # get the 32bit hash of the packet
        hash_value = self.toeplitz_hash(bytes_to_hash)

        # determine the offset in the redirection table
        offset = hash_value & (len(self.reta) - 1)

        return self.reta[offset]


def balanced_traffic(
    algo: RSSAlgo,
    traffic_template: TrafficTemplate,
    check_reverse_traffic: bool = False,
    all_flows: bool = False,
) -> typing.Iterator[typing.Tuple[int, int, Packet]]:
    """Yield ``(queue, reverse_queue, packet)`` tuples from *traffic_template*.

    Unless *all_flows* is set, at most one flow is yielded per queue, and
    iteration stops once every queue has a flow. With *check_reverse_traffic*,
    the reverse direction of the selected flows must also land on distinct
    queues.
    """
    queues = set()
    queues_reverse = set()

    for pkt in traffic_template:
        q = algo.get_queue_index(pkt)

        # check if q is already filled
        if not all_flows and q in queues:
            continue

        qr = algo.get_queue_index(pkt.reverse())

        if check_reverse_traffic:
            # check if qr is already filled
            if not all_flows and qr in queues_reverse:
                continue
            # mark this queue as matched
            queues_reverse.add(qr)

        # mark this queue as filled
        queues.add(q)

        yield q, qr, pkt

        # stop when all queues have been filled
        if not all_flows and len(queues) == algo.queues_count:
            break


# Port range meaning "no L4 port": flows are then hashed on IP addresses only.
NO_PORT = (0,)

# fmt: off
# rss_intel_key, see drivers/net/ixgbe/ixgbe_rxtx.c
RSS_KEY_INTEL = bytes(
    (
        0x6d, 0x5a, 0x56, 0xda, 0x25, 0x5b, 0x0e, 0xc2,
        0x41, 0x67, 0x25, 0x3d, 0x43, 0xa3, 0x8f, 0xb0,
        0xd0, 0xca, 0x2b, 0xcb, 0xae, 0x7b, 0x30, 0xb4,
        0x77, 0xcb, 0x2d, 0xa3, 0x80, 0x30, 0xf2, 0x0c,
        0x6a, 0x42, 0xb7, 0x3b, 0xbe, 0xac, 0x01, 0xfa,
    )
)
# rss_hash_default_key, see drivers/net/mlx5/mlx5_rxq.c
RSS_KEY_MLX = bytes(
    (
        0x2c, 0xc6, 0x81, 0xd1, 0x5b, 0xdb, 0xf4, 0xf7,
        0xfc, 0xa2, 0x83, 0x19, 0xdb, 0x1a, 0x3e, 0x94,
        0x6b, 0x9e, 0x38, 0xd9, 0x2c, 0x9c, 0x03, 0xd1,
        0xad, 0x99, 0x44, 0xa7, 0xd9, 0x56, 0x3d, 0x59,
        0x06, 0x3c, 0x25, 0xf3, 0xfc, 0x1f, 0xdc, 0x2a,
    )
)
# rss_key_default, see drivers/net/i40e/i40e_ethdev.c
# i40e is the only driver that takes 52 bytes keys
RSS_KEY_I40E = bytes(
    (
        0x44, 0x39, 0x79, 0x6b, 0xb5, 0x4c, 0x50, 0x23,
        0xb6, 0x75, 0xea, 0x5b, 0x12, 0x4f, 0x9f, 0x30,
        0xb8, 0xa2, 0xc0, 0x3d, 0xdf, 0xdc, 0x4d, 0x02,
        0xa0, 0x8c, 0x9b, 0x33, 0x4a, 0xf6, 0x4a, 0x4c,
        0x05, 0xc6, 0xfa, 0x34, 0x39, 0x58, 0xd8, 0x55,
        0x7d, 0x99, 0x58, 0x3a, 0xe1, 0x38, 0xc9, 0x2e,
        0x81, 0x15, 0x03, 0x66,
    )
)
# fmt: on
DEFAULT_DRIVER_KEYS = {
    "intel": RSS_KEY_INTEL,
    "mlx": RSS_KEY_MLX,
    "i40e": RSS_KEY_I40E,
}


def rss_key(value):
    """argparse type: a well-known key name or a 40/52-byte hex string."""
    if value in DEFAULT_DRIVER_KEYS:
        return DEFAULT_DRIVER_KEYS[value]
    try:
        key = binascii.unhexlify(value)
    except (TypeError, ValueError) as e:
        raise argparse.ArgumentTypeError(str(e)) from e
    if len(key) not in (40, 52):
        raise argparse.ArgumentTypeError("The key must be 40 or 52 bytes long")
    return key


def port_range(value):
    """argparse type: a single port ``N`` or a range ``START-END``.

    ``END`` is excluded, as with Python's ``range()``. An empty range falls
    back to ``NO_PORT``.

    Raises:
        argparse.ArgumentTypeError: malformed input, or a port outside
            0-65535, which could not be packed as a 16-bit field.
    """
    try:
        if "-" in value:
            start, stop = value.split("-")
            ports = tuple(range(int(start), int(stop)))
        else:
            ports = (int(value),)
    except ValueError as e:
        raise argparse.ArgumentTypeError(str(e)) from e
    # ports is ascending, so checking both ends covers every element
    if ports and not (0 <= ports[0] and ports[-1] <= MAX_PORT):
        raise argparse.ArgumentTypeError(f"ports must be between 0 and {MAX_PORT}")
    return ports or NO_PORT


def positive_int(value):
    """argparse type: an integer strictly greater than zero."""
    try:
        i = int(value)
    except ValueError as e:
        raise argparse.ArgumentTypeError(str(e)) from e
    if i <= 0:
        raise argparse.ArgumentTypeError("must be strictly positive")
    return i


def power_of_two(value):
    """argparse type: a strictly positive power of two."""
    i = positive_int(value)
    if i & (i - 1) != 0:
        raise argparse.ArgumentTypeError("must be a power of two")
    return i


def parse_args():
    """Parse and cross-validate the command line arguments."""
    parser = argparse.ArgumentParser(description=__doc__)

    parser.add_argument(
        "rx_queues",
        metavar="RX_QUEUES",
        type=positive_int,
        help="""
        The number of RX queues to fill.
        """,
    )
    parser.add_argument(
        "ip_src",
        metavar="SRC",
        type=ipaddress.ip_network,
        help="""
        The source IP network/address.
        """,
    )
    parser.add_argument(
        "ip_dst",
        metavar="DST",
        type=ipaddress.ip_network,
        help="""
        The destination IP network/address.
        """,
    )
    parser.add_argument(
        "-s",
        "--sport-range",
        type=port_range,
        default=NO_PORT,
        help="""
        The layer 4 (TCP/UDP) source port range.
        Can be a single fixed value or a range <start>-<end> (<end> excluded).
        """,
    )
    parser.add_argument(
        "-d",
        "--dport-range",
        type=port_range,
        default=NO_PORT,
        help="""
        The layer 4 (TCP/UDP) destination port range.
        Can be a single fixed value or a range <start>-<end> (<end> excluded).
        """,
    )
    parser.add_argument(
        "-r",
        "--check-reverse-traffic",
        action="store_true",
        help="""
        The reversed traffic (source <-> dest) should also be evenly balanced
        in the queues.
        """,
    )
    parser.add_argument(
        "-k",
        "--rss-key",
        default=RSS_KEY_INTEL,
        type=rss_key,
        help="""
        The random 40 or 52 bytes key used to compute the RSS hash. This
        option supports either a well-known name or the hex value of the key
        (well-known names: "intel", "mlx", "i40e", default: "intel").
        """,
    )
    parser.add_argument(
        "-t",
        "--reta-size",
        default=128,
        type=power_of_two,
        help="""
        Size of the redirection table or "RETA" (default: 128).
        """,
    )
    parser.add_argument(
        "-a",
        "--all-flows",
        action="store_true",
        help="""
        Output ALL flows that can be created based on source and destination
        address/port ranges along their matched queue number. ATTENTION: this
        option can produce very long outputs depending on the address and port
        range sizes.
        """,
    )
    parser.add_argument(
        "-j",
        "--json",
        action="store_true",
        help="""
        Output in parseable JSON format.
        """,
    )

    args = parser.parse_args()

    if args.ip_src.version != args.ip_dst.version:
        parser.error(
            f"{args.ip_src} and {args.ip_dst} don't have the same protocol version"
        )
    if args.reta_size < args.rx_queues:
        parser.error("RETA_SIZE must be greater than or equal to RX_QUEUES")

    return args


def _print_json(results) -> None:
    """Print ``(queue, reverse_queue, packet)`` results as a JSON list."""
    flows = [
        {
            "queue": q,
            "queue_reverse": qr,
            "src_ip": str(pkt.ip_src),
            "dst_ip": str(pkt.ip_dst),
            "src_port": pkt.l4_sport,
            "dst_port": pkt.l4_dport,
        }
        for q, qr, pkt in results
    ]
    print(json.dumps(flows, indent=2))


def _print_table(results, use_l4_port: bool, check_reverse_traffic: bool) -> None:
    """Print ``(queue, reverse_queue, packet)`` results as an aligned table."""
    if use_l4_port:
        header = ["SRC_IP", "SPORT", "DST_IP", "DPORT", "QUEUE"]
    else:
        header = ["SRC_IP", "DST_IP", "QUEUE"]
    if check_reverse_traffic:
        header.append("QUEUE_REVERSE")

    rows = [tuple(header)]
    widths = [len(h) for h in header]

    for q, qr, pkt in results:
        if use_l4_port:
            row = [pkt.ip_src, pkt.l4_sport, pkt.ip_dst, pkt.l4_dport, q]
        else:
            row = [pkt.ip_src, pkt.ip_dst, q]
        if check_reverse_traffic:
            row.append(qr)
        cells = tuple(str(r) for r in row)
        widths = [max(w, len(c)) for w, c in zip(widths, cells)]
        rows.append(cells)

    columns = [f"%-{w}s" for w in widths]
    columns[-1] = "%s"  # avoid trailing whitespace
    line_fmt = " ".join(columns)
    for row in rows:
        print(line_fmt % row)


def main():
    """Entry point: compute the balanced flows and print them."""
    args = parse_args()
    # L4 ports are only part of the hash input when a port range was given.
    use_l4_port = args.sport_range != NO_PORT or args.dport_range != NO_PORT

    algo = RSSAlgo(
        queues_count=args.rx_queues,
        key=args.rss_key,
        reta_size=args.reta_size,
        use_l4_port=use_l4_port,
    )
    template = TrafficTemplate(
        args.ip_src,
        args.ip_dst,
        args.sport_range,
        args.dport_range,
    )

    results = balanced_traffic(
        algo, template, args.check_reverse_traffic, args.all_flows
    )

    if args.json:
        _print_json(results)
    else:
        _print_table(results, use_l4_port, args.check_reverse_traffic)


if __name__ == "__main__":
    main()