xref: /dpdk/dts/framework/utils.py (revision b935bdc3da26ab86ec775dfad3aa63a1a61f5667)
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright(c) 2010-2014 Intel Corporation
3# Copyright(c) 2022-2023 PANTHEON.tech s.r.o.
4# Copyright(c) 2022-2023 University of New Hampshire
5# Copyright(c) 2024 Arm Limited
6
7"""Various utility classes and functions.
8
9These are used in multiple modules across the framework. They're here because
10they provide some non-specific functionality, greatly simplify imports or just don't
11fit elsewhere.
12
13Attributes:
14    REGEX_FOR_PCI_ADDRESS: The regex representing a PCI address, e.g. ``0000:00:08.0``.
15"""
16
17import fnmatch
18import json
19import os
20import random
21import tarfile
22from enum import Enum, Flag
23from pathlib import Path
24from typing import Any, Callable
25
26from scapy.layers.inet import IP, TCP, UDP, Ether  # type: ignore[import-untyped]
27from scapy.packet import Packet  # type: ignore[import-untyped]
28
29from .exception import InternalError
30
# PCI address in ``domain:bus:device.function`` form, e.g. ``0000:00:08.0``.
# The dot before the function number is escaped so it only matches a literal ".".
REGEX_FOR_PCI_ADDRESS: str = r"[0-9a-fA-F]{4}:[0-9a-fA-F]{2}:[0-9a-fA-F]{2}\.[0-9]{1}"
# MAC address as six colon- or hyphen-separated octets, e.g. ``aa:bb:cc:dd:ee:ff``.
_REGEX_FOR_COLON_OR_HYPHEN_SEP_MAC: str = r"(?:[\da-fA-F]{2}[:-]){5}[\da-fA-F]{2}"
# MAC address as three dot-separated groups of four hex digits, e.g. ``aabb.ccdd.eeff``.
_REGEX_FOR_DOT_SEP_MAC: str = r"(?:[\da-fA-F]{4}\.){2}[\da-fA-F]{4}"
# The alternation is wrapped in a non-capturing group so that the "|" stays local when this
# pattern is interpolated into a larger regular expression (no extra capture group is added).
REGEX_FOR_MAC_ADDRESS: str = rf"(?:{_REGEX_FOR_COLON_OR_HYPHEN_SEP_MAC}|{_REGEX_FOR_DOT_SEP_MAC})"
REGEX_FOR_BASE64_ENCODING: str = "[-a-zA-Z0-9+\\/]*={0,3}"
36
37
def expand_range(range_str: str) -> list[int]:
    """Expand `range_str` into the list of integers it denotes.

    `range_str` may take one of two forms:

        * ``n`` - a single integer,
        * ``n-m`` - a range of integers from ``n`` to ``m``.

    Both ends of the range are included. An empty string yields an empty list.

    Args:
        range_str: The range to expand.

    Returns:
        All the numbers from the range.
    """
    if not range_str:
        return []

    bounds = range_str.split("-")
    # int() raises ValueError for anything that isn't an integer, which doubles as input
    # validation; for a single number the first and last bound are the same element.
    start = int(bounds[0])
    stop = int(bounds[-1])
    return list(range(start, stop + 1))
62
63
def get_packet_summaries(packets: list[Packet]) -> str:
    """Build a human-readable summary of `packets`.

    A single packet is summarized on its own; any other number of packets is rendered as an
    indented JSON list of the individual packet summaries.

    Args:
        packets: The packets to format.

    Returns:
        The summary of `packets`, prefixed with a ``Packet contents:`` header line.
    """
    if len(packets) != 1:
        summaries = json.dumps([packet.summary() for packet in packets], indent=4)
    else:
        summaries = packets[0].summary()
    return f"Packet contents: \n{summaries}"
78
79
class StrEnum(Enum):
    """Enum whose automatically generated member values are the member names.

    Members declared with :func:`enum.auto` take their own name as their value, and converting
    any member to a string yields its name.
    """

    @staticmethod
    def _generate_next_value_(name: str, start: int, count: int, last_values: object) -> str:
        # Hook invoked by enum.auto(): the member's own name becomes its value.
        generated_value: str = name
        return generated_value

    def __str__(self) -> str:
        """Return the member name, e.g. ``str(Member.A) == "A"``."""
        member_name: str = self.name
        return member_name
90
91
92class MesonArgs:
93    """Aggregate the arguments needed to build DPDK."""
94
95    _default_library: str
96
97    def __init__(self, default_library: str | None = None, **dpdk_args: str | bool):
98        """Initialize the meson arguments.
99
100        Args:
101            default_library: The default library type, Meson supports ``shared``, ``static`` and
102                ``both``. Defaults to :data:`None`, in which case the argument won't be used.
103            dpdk_args: The arguments found in ``meson_options.txt`` in root DPDK directory.
104                Do not use ``-D`` with them.
105
106        Example:
107            ::
108
109                meson_args = MesonArgs(enable_kmods=True).
110        """
111        self._default_library = f"--default-library={default_library}" if default_library else ""
112        self._dpdk_args = " ".join(
113            (
114                f"-D{dpdk_arg_name}={dpdk_arg_value}"
115                for dpdk_arg_name, dpdk_arg_value in dpdk_args.items()
116            )
117        )
118
119    def __str__(self) -> str:
120        """The actual args."""
121        return " ".join(f"{self._default_library} {self._dpdk_args}".split())
122
123
class TarCompressionFormat(StrEnum):
    """Compression formats that tar can use.

    The member names are the shell compression commands and the member values are the
    associated file extensions.

    The ``none`` member represents no compression, only archiving with tar. Its value is ``tar``
    to indicate that the file is an uncompressed tar archive.
    """

    none = "tar"
    gzip = "gz"
    compress = "Z"
    bzip2 = "bz2"
    lzip = "lz"
    lzma = "lzma"
    lzop = "lzo"
    xz = "xz"
    zstd = "zst"

    @property
    def extension(self) -> str:
        """The file extension of an archive in this format.

        ``tar`` for :attr:`none`; otherwise ``tar.<value>``, e.g. ``tar.gz`` for :attr:`gzip`.
        """
        tar_extension = type(self).none.value
        if self is type(self).none:
            return tar_extension
        return f"{tar_extension}.{self.value}"
153
154
155def convert_to_list_of_string(value: Any | list[Any]) -> list[str]:
156    """Convert the input to the list of strings."""
157    return list(map(str, value) if isinstance(value, list) else str(value))
158
159
def create_tarball(
    dir_path: Path,
    compress_format: TarCompressionFormat = TarCompressionFormat.none,
    exclude: Any | list[Any] | None = None,
) -> Path:
    """Create a tarball from the contents of the specified directory.

    This method creates a tarball containing all files and directories within `dir_path`.
    The tarball is saved next to `dir_path` (in its parent directory). Its name is the directory
    name with the format's extension appended, e.g. ``dpdk-24.07`` -> ``dpdk-24.07.tar.gz``.

    Args:
        dir_path: The directory path.
        compress_format: The compression format to use. Defaults to no compression.
        exclude: Patterns for files or directories to exclude from the tarball.
                These patterns are used with `fnmatch.fnmatch` to filter out files.
                They are matched against the base name of each archive member only.

    Returns:
        The path to the created tarball.

    Raises:
        tarfile.CompressionError: If :mod:`tarfile` does not support `compress_format`
            as a write mode.
    """

    def create_filter_function(exclude_patterns: Any | list[Any] | None) -> Callable | None:
        """Create a filter function based on the provided exclude patterns.

        Args:
            exclude_patterns: Patterns for files or directories to exclude from the tarball.
                These patterns are used with `fnmatch.fnmatch` to filter out files.

        Returns:
            The filter function that excludes files based on the patterns, or :data:`None`
            when there is nothing to exclude.
        """
        if not exclude_patterns:
            return None
        patterns = convert_to_list_of_string(exclude_patterns)

        def filter_func(tarinfo: tarfile.TarInfo) -> tarfile.TarInfo | None:
            # Returning None tells tarfile to skip the member (and, for a directory,
            # everything beneath it).
            file_name = os.path.basename(tarinfo.name)
            if any(fnmatch.fnmatch(file_name, pattern) for pattern in patterns):
                return None
            return tarinfo

        return filter_func

    # Append the extension instead of using Path.with_suffix(): with_suffix() replaces an
    # existing suffix, so a directory like "dpdk-24.07" would become "dpdk-24.tar.gz".
    target_tarball_path = dir_path.with_name(f"{dir_path.name}.{compress_format.extension}")
    with tarfile.open(target_tarball_path, f"w:{compress_format.value}") as tar:
        tar.add(dir_path, arcname=dir_path.name, filter=create_filter_function(exclude))

    return target_tarball_path
207
208
209def extract_tarball(tar_path: str | Path):
210    """Extract the contents of a tarball.
211
212    The tarball will be extracted in the same path as `tar_path` parent path.
213
214    Args:
215        tar_path: The path to the tarball file to extract.
216    """
217    with tarfile.open(tar_path, "r") as tar:
218        tar.extractall(path=Path(tar_path).parent)
219
220
class PacketProtocols(Flag):
    """Flag specifying which protocols to use for packet generation.

    Every L4 member also carries the :attr:`IP` bit, since TCP and UDP are carried over IP.
    """

    #: The IP layer only.
    IP = 1
    #: TCP over IP (includes the IP bit).
    TCP = IP | 2
    #: UDP over IP (includes the IP bit).
    UDP = IP | 4
    #: Both TCP and UDP, and therefore IP as well.
    ALL = TCP | UDP
232
233
def generate_random_packets(
    number_of: int,
    payload_size: int = 1500,
    protocols: PacketProtocols = PacketProtocols.ALL,
    ports_range: range = range(1024, 49152),
    mtu: int = 1500,
) -> list[Packet]:
    """Generate a number of random packets.

    The payload of the packets will consist of random bytes. If `payload_size` is too big, then the
    maximum payload size allowed for the specific packet type is used. The size is calculated based
    on the specified `mtu`, therefore it is essential that `mtu` is set correctly to match the MTU
    of the port that will send out the generated packets.

    If `protocols` has any L4 protocol enabled then all the packets are generated with any of
    the specified L4 protocols chosen at random. If only :attr:`~PacketProtocols.IP` is set, then
    only L3 packets are generated.

    If L4 packets will be generated, then the TCP/UDP ports to be used will be chosen at random from
    `ports_range`.

    Args:
        number_of: The number of packets to generate.
        payload_size: The packet payload size to generate, capped based on `mtu`.
        protocols: The protocols to use for the generated packets.
        ports_range: The range of L4 port numbers to use. Used only if `protocols` has L4 protocols.
        mtu: The MTU of the NIC port that will send out the generated packets.

    Raises:
        InternalError: If the `payload_size` is invalid.

    Returns:
        A list containing the randomly generated packets.
    """
    if payload_size < 0:
        raise InternalError(f"An invalid payload_size of {payload_size} was given.")

    # Use flag membership rather than a bitwise AND. Every L4 flag includes the IP bit, so
    # e.g. ``PacketProtocols.UDP & PacketProtocols.TCP`` evaluates to the (truthy) IP flag.
    # That would wrongly enable TCP for UDP-only or IP-only requests.
    l4_factories = []
    if PacketProtocols.TCP in protocols:
        l4_factories.append(TCP)
    if PacketProtocols.UDP in protocols:
        l4_factories.append(UDP)

    def _make_packet() -> Packet:
        """Build one Ether[/IP[/L4]] packet padded with a random payload."""
        packet = Ether()

        if PacketProtocols.IP in protocols:
            packet /= IP()

        if l4_factories:
            src_port, dst_port = random.choices(ports_range, k=2)
            packet /= random.choice(l4_factories)(sport=src_port, dport=dst_port)

        # NOTE(review): len(packet) includes the Ethernet header, which the MTU conventionally
        # excludes, so this cap is conservative by the size of the L2 header.
        max_payload_size = mtu - len(packet)
        usable_payload_size = min(payload_size, max_payload_size)
        return packet / random.randbytes(usable_payload_size)

    return [_make_packet() for _ in range(number_of)]
292
293
class MultiInheritanceBaseClass:
    """A base class for classes utilizing multiple inheritance.

    This class enables its subclasses to support both single and multiple inheritance. It acts as
    a stopping point in the chain of calls to the constructors of superclasses. Placed at the end
    of the Method Resolution Order (MRO), it lets subclasses call ``super().__init__(...)`` with
    arbitrary arguments without those arguments being forwarded any further.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Discard any constructor arguments and continue the ``super().__init__`` chain.

        Args:
            args: Positional arguments; ignored.
            kwargs: Keyword arguments; ignored.
        """
        # Deliberately unused: the next __init__ in the MRO (normally object's) takes no
        # extra arguments.
        del args, kwargs
        super().__init__()
306
307
def to_pascal_case(text: str) -> str:
    """Convert `text` from snake_case to PascalCase.

    Each underscore-separated segment is capitalized (first letter upper-cased, the rest
    lower-cased) and the segments are concatenated, e.g. ``port_stats`` -> ``PortStats``.
    """
    return "".join(map(str.capitalize, text.split("_")))
311