xref: /dpdk/dts/framework/utils.py (revision 80158fd411bb06d1a1c22f56fd387ecc49eaf15d)
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright(c) 2010-2014 Intel Corporation
3# Copyright(c) 2022-2023 PANTHEON.tech s.r.o.
4# Copyright(c) 2022-2023 University of New Hampshire
5# Copyright(c) 2024 Arm Limited
6
7"""Various utility classes and functions.
8
9These are used in multiple modules across the framework. They're here because
10they provide some non-specific functionality, greatly simplify imports or just don't
11fit elsewhere.
12
13Attributes:
14    REGEX_FOR_PCI_ADDRESS: The regex representing a PCI address, e.g. ``0000:00:08.0``.
15"""
16
17import atexit
18import fnmatch
19import json
20import os
21import random
22import subprocess
23import tarfile
24from enum import Enum, Flag
25from pathlib import Path
26from subprocess import SubprocessError
27from typing import Any, Callable
28
29from scapy.layers.inet import IP, TCP, UDP, Ether  # type: ignore[import-untyped]
30from scapy.packet import Packet  # type: ignore[import-untyped]
31
32from .exception import ConfigurationError, InternalError
33
#: The regex matching a PCI address in the domain:bus:device.function format,
#: e.g. ``0000:00:08.0``. No delimiters: it's a plain Python regex, usable with :mod:`re`.
REGEX_FOR_PCI_ADDRESS: str = r"[0-9a-fA-F]{4}:[0-9a-fA-F]{2}:[0-9a-fA-F]{2}\.[0-9]"
# A MAC address made of six colon or hyphen separated octets, e.g. ``00:11:22:aa:bb:cc``.
_REGEX_FOR_COLON_OR_HYPHEN_SEP_MAC: str = r"(?:[\da-fA-F]{2}[:-]){5}[\da-fA-F]{2}"
# A MAC address made of three dot separated groups of four hex digits, e.g. ``0011.22aa.bbcc``.
_REGEX_FOR_DOT_SEP_MAC: str = r"(?:[\da-fA-F]{4}\.){2}[\da-fA-F]{4}"
#: The regex matching a MAC address in any of the formats above. The alternation is wrapped in
#: a non-capturing group so the regex can be safely embedded into larger patterns.
REGEX_FOR_MAC_ADDRESS: str = rf"(?:{_REGEX_FOR_COLON_OR_HYPHEN_SEP_MAC}|{_REGEX_FOR_DOT_SEP_MAC})"
#: The regex matching a Base64 encoded string.
REGEX_FOR_BASE64_ENCODING: str = "[-a-zA-Z0-9+\\/]*={0,3}"
39
40
def expand_range(range_str: str) -> list[int]:
    """Expand `range_str` into the list of integers it denotes.

    `range_str` is either a single integer ``n`` or a dash-separated range ``n-m``. The result
    includes both ``n`` and ``m``. An empty string yields an empty list, as does a range whose
    end is lower than its start.

    Args:
        range_str: The range to expand, e.g. ``"4"`` or ``"0-3"``.

    Returns:
        All the numbers from the range, in ascending order.

    Raises:
        ValueError: If a boundary isn't a valid integer.
    """
    if not range_str:
        return []

    # Only the first and the last dash-separated fields are used. int() raises ValueError for
    # anything that isn't an integer, which doubles as input validation.
    bounds = range_str.split("-")
    first, last = int(bounds[0]), int(bounds[-1])
    return list(range(first, last + 1))
65
66
def get_packet_summaries(packets: list[Packet]) -> str:
    """Build a human-readable summary of `packets`.

    A single packet is summarized by its own ``summary()``. Any other number of packets is
    rendered as an indented JSON list holding one summary string per packet.

    Args:
        packets: The packets to summarize.

    Returns:
        The summary, prefixed with ``Packet contents:`` and a newline.
    """
    if len(packets) != 1:
        summary_list = [packet.summary() for packet in packets]
        rendered = json.dumps(summary_list, indent=4)
    else:
        (only_packet,) = packets
        rendered = only_packet.summary()
    return f"Packet contents: \n{rendered}"
81
82
def get_commit_id(rev_id: str) -> str:
    """Resolve a Git revision ID with ``git rev-parse --verify``.

    The command runs in the current working directory.

    NOTE: ``git rev-parse`` resolves an annotated tag to the tag object, not the commit it
    points to.

    Args:
        rev_id: The Git revision ID, e.g. a branch name, a tag or an abbreviated commit ID.

    Returns:
        The full object ID printed by ``git rev-parse``, with surrounding whitespace stripped.

    Raises:
        ConfigurationError: The ``git rev-parse`` command failed, suggesting
            an invalid or ambiguous revision ID was supplied.
    """
    command = ["git", "rev-parse", "--verify", rev_id]
    completed = subprocess.run(command, capture_output=True, text=True)
    if completed.returncode == 0:
        return completed.stdout.strip()

    raise ConfigurationError(
        f"{rev_id} is not a valid git reference.\n"
        f"Command: {completed.args}\n"
        f"Stdout: {completed.stdout}\n"
        f"Stderr: {completed.stderr}"
    )
106
107
class StrEnum(Enum):
    """Enum whose :func:`enum.auto` values are the member names.

    Converting a member to a string yields just the member's name instead of the default
    ``ClassName.member`` form.
    """

    @staticmethod
    def _generate_next_value_(name: str, start: int, count: int, last_values: object) -> str:
        # Hook called for members assigned enum.auto(): the value becomes the member's name.
        member_value = name
        return member_value

    def __str__(self) -> str:
        """Return the member's name, e.g. ``str(Color.red) == "red"``."""
        member_name: str = self.name
        return member_name
118
119
120class MesonArgs:
121    """Aggregate the arguments needed to build DPDK."""
122
123    _default_library: str
124
125    def __init__(self, default_library: str | None = None, **dpdk_args: str | bool):
126        """Initialize the meson arguments.
127
128        Args:
129            default_library: The default library type, Meson supports ``shared``, ``static`` and
130                ``both``. Defaults to :data:`None`, in which case the argument won't be used.
131            dpdk_args: The arguments found in ``meson_options.txt`` in root DPDK directory.
132                Do not use ``-D`` with them.
133
134        Example:
135            ::
136
137                meson_args = MesonArgs(enable_kmods=True).
138        """
139        self._default_library = f"--default-library={default_library}" if default_library else ""
140        self._dpdk_args = " ".join(
141            (
142                f"-D{dpdk_arg_name}={dpdk_arg_value}"
143                for dpdk_arg_name, dpdk_arg_value in dpdk_args.items()
144            )
145        )
146
147    def __str__(self) -> str:
148        """The actual args."""
149        return " ".join(f"{self._default_library} {self._dpdk_args}".split())
150
151
class TarCompressionFormat(StrEnum):
    """Compression formats that tar can use.

    Enum names are the shell compression commands
    and Enum values are the associated file extensions.

    The 'none' member represents no compression, only archiving with tar.
    Its value is set to 'tar' to indicate that the file is an uncompressed tar archive.
    """

    none = "tar"
    gzip = "gz"
    compress = "Z"
    bzip2 = "bz2"
    lzip = "lz"
    lzma = "lzma"
    lzop = "lzo"
    xz = "xz"
    zstd = "zst"

    @property
    def extension(self) -> str:
        """The file extension associated with the compression format.

        If the compression format is 'none', the extension will be 'tar'.
        For other compression formats, the extension will be in the format
        'tar.{compression format}', e.g. 'tar.xz'.
        """
        # Look the member up on the class: accessing one member through another
        # (``self.none``) is deprecated on some Python versions (3.11 warns about it).
        uncompressed = type(self).none
        if self is uncompressed:
            return uncompressed.value
        return f"{uncompressed.value}.{self.value}"
181
182
class DPDKGitTarball:
    """Compressed tarball of DPDK from the repository.

    The class supports the :class:`os.PathLike` protocol,
    which is used to get the Path of the tarball::

        from pathlib import Path
        tarball = DPDKGitTarball("HEAD", "output")
        tarball_path = Path(tarball)
    """

    _git_ref: str
    _tar_compression_format: TarCompressionFormat
    _tarball_dir: Path
    _tarball_name: str
    _tarball_path: Path | None

    def __init__(
        self,
        git_ref: str,
        output_dir: str,
        tar_compression_format: TarCompressionFormat = TarCompressionFormat.xz,
    ):
        """Create the tarball during initialization.

        The DPDK version is specified with `git_ref`. The tarball will be compressed with
        `tar_compression_format`, which must be supported by the DTS execution environment.
        The resulting tarball will be put into the ``tarball`` subdirectory of `output_dir`.
        If a tarball with the same name already exists there, it's reused as is.

        Args:
            git_ref: A git commit ID, tag ID or tree ID.
            output_dir: The directory where to put the resulting tarball.
            tar_compression_format: The compression format to use.

        Raises:
            SubprocessError: If the tarball could not be created.
        """
        self._git_ref = git_ref
        self._tar_compression_format = tar_compression_format

        self._tarball_dir = Path(output_dir, "tarball")

        self._create_tarball_dir()

        self._tarball_name = (
            f"dpdk-tarball-{self._git_ref}.{self._tar_compression_format.extension}"
        )
        self._tarball_path = self._check_tarball_path()
        if not self._tarball_path:
            self._create_tarball()

    def _create_tarball_dir(self) -> None:
        """Create the directory holding the tarball, if it doesn't exist yet."""
        os.makedirs(self._tarball_dir, exist_ok=True)

    def _check_tarball_path(self) -> Path | None:
        """Return the path of an already existing tarball, or :data:`None` if there is none."""
        if self._tarball_name in os.listdir(self._tarball_dir):
            return Path(self._tarball_dir, self._tarball_name)
        return None

    def _create_tarball(self) -> None:
        """Create the tarball of `_git_ref` with ``git archive``.

        The commands are run without a shell, so refs and paths are never interpreted by one.

        Raises:
            SubprocessError: If the repository root can't be determined, ``git archive`` fails,
                or the compression command can't be started or fails.
        """
        tarball_path = Path(self._tarball_dir, self._tarball_name)
        self._tarball_path = tarball_path

        # Delete a partially written tarball if creation fails; otherwise a later run would find
        # it with _check_tarball_path and reuse it. Unregistered once creation succeeds.
        atexit.register(self._delete_tarball)

        # Archive from the repository root: without path arguments, git archive includes only
        # the current working directory, which may be a subdirectory of the repository.
        archive_command = [
            "git",
            "-C",
            self._get_repository_root(),
            "archive",
            "--format=tar",
            f"--prefix=dpdk-tarball-{self._git_ref}/",
            self._git_ref,
        ]

        with open(tarball_path, "wb") as tarball_file:
            if self._tar_compression_format is TarCompressionFormat.none:
                # No compression: store the tar archive produced by git as is.
                archive = subprocess.run(
                    archive_command, stdout=tarball_file, stderr=subprocess.PIPE
                )
                git_returncode, git_stderr = archive.returncode, archive.stderr
            else:
                git_returncode, git_stderr = self._compress_archive(
                    archive_command, tarball_file.fileno()
                )

        # Check git's own exit code: in a shell pipeline only the compressor's exit code would
        # be seen, and a failed git archive would leave behind an empty or truncated tarball.
        if git_returncode != 0:
            raise SubprocessError(
                f"Git archive creation failed with exit code {git_returncode}.\n"
                f"Command: {archive_command}\n"
                f"Stderr: {git_stderr.decode(errors='replace')}"
            )

        atexit.unregister(self._delete_tarball)

    @staticmethod
    def _get_repository_root() -> str:
        """Return the top-level directory of the git repository of the working directory.

        Raises:
            SubprocessError: If ``git rev-parse --show-toplevel`` fails.
        """
        result = subprocess.run(
            ["git", "rev-parse", "--show-toplevel"], text=True, capture_output=True
        )
        if result.returncode != 0:
            raise SubprocessError(
                f"Could not determine the git repository root, exit code {result.returncode}.\n"
                f"Command: {result.args}\n"
                f"Stdout: {result.stdout}\n"
                f"Stderr: {result.stderr}"
            )
        # Strip only the trailing newline printed by git.
        return result.stdout.rstrip("\n")

    def _compress_archive(self, archive_command: list[str], output_fd: int) -> tuple[int, bytes]:
        """Pipe the output of `archive_command` through the compression command into `output_fd`.

        The compression command is the name of the compression format member, run as a filter
        reading standard input and writing standard output.

        Args:
            archive_command: The ``git archive`` command line.
            output_fd: The file descriptor to write the compressed tarball to.

        Returns:
            The exit code and the standard error output of `archive_command`.

        Raises:
            SubprocessError: If the compression command can't be started or fails.
        """
        compress_command = [self._tar_compression_format.name]
        with subprocess.Popen(
            archive_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
        ) as git_archive:
            # Both streams were requested as pipes; this only narrows the types.
            assert git_archive.stdout is not None and git_archive.stderr is not None
            try:
                compression = subprocess.run(
                    compress_command,
                    stdin=git_archive.stdout,
                    stdout=output_fd,
                    stderr=subprocess.PIPE,
                )
            except OSError as err:
                # Don't leave git archive blocked on a pipe nobody reads; leaving the with
                # block then closes the pipes and reaps the process.
                git_archive.kill()
                raise SubprocessError(
                    f"Could not run the compression command {compress_command}: {err}"
                ) from err
            # Close our read end of the pipe so that git archive terminates instead of blocking
            # if the compression command stopped reading early.
            git_archive.stdout.close()
            # git archive is expected to write only short messages to stderr, which fit into the
            # pipe buffer, so reading them after the compression finished shouldn't deadlock.
            git_stderr = git_archive.stderr.read()
            git_returncode = git_archive.wait()

        if compression.returncode != 0:
            raise SubprocessError(
                f"Tarball compression failed with exit code {compression.returncode}.\n"
                f"Command: {compression.args}\n"
                f"Stderr: {compression.stderr.decode(errors='replace')}"
            )
        return git_returncode, git_stderr

    def _delete_tarball(self) -> None:
        """Remove the tarball file, if it has been (at least partially) created."""
        if self._tarball_path and os.path.exists(self._tarball_path):
            os.remove(self._tarball_path)

    def __fspath__(self) -> str:
        """The os.PathLike protocol implementation."""
        return str(self._tarball_path)
270
271
272def convert_to_list_of_string(value: Any | list[Any]) -> list[str]:
273    """Convert the input to the list of strings."""
274    return list(map(str, value) if isinstance(value, list) else str(value))
275
276
def create_tarball(
    dir_path: Path,
    compress_format: TarCompressionFormat = TarCompressionFormat.none,
    exclude: Any | list[Any] | None = None,
) -> Path:
    """Create a tarball from the contents of the specified directory.

    This method creates a tarball containing all files and directories within `dir_path`,
    stored in the tarball under the directory name of `dir_path`. The tarball is saved next to
    `dir_path` (in its parent directory) and is named after `dir_path` with the format's
    extension appended, e.g. ``dpdk-24.11`` becomes ``dpdk-24.11.tar.xz``.

    Args:
        dir_path: The directory path.
        compress_format: The compression format to use. Defaults to no compression.
            The format's value is passed to :func:`tarfile.open`, so only formats supported by
            :mod:`tarfile` work (e.g. ``none``, ``gzip``, ``bzip2`` and ``xz``).
        exclude: Patterns for files or directories to exclude from the tarball.
            These patterns are matched with `fnmatch.fnmatch` against the base name of each
            entry. An excluded directory is skipped along with all of its contents.

    Returns:
        The path to the created tarball.

    Raises:
        tarfile.CompressionError: If :mod:`tarfile` doesn't support `compress_format`.
    """

    def create_filter_function(exclude_patterns: str | list[str] | None) -> Callable | None:
        """Create a filter function based on the provided exclude patterns.

        Args:
            exclude_patterns: Patterns for files or directories to exclude from the tarball.
                These patterns are used with `fnmatch.fnmatch` to filter out files.

        Returns:
            The filter function that excludes files based on the patterns,
            or :data:`None` if there are no patterns.
        """
        if not exclude_patterns:
            return None
        patterns = convert_to_list_of_string(exclude_patterns)

        def filter_func(tarinfo: tarfile.TarInfo) -> tarfile.TarInfo | None:
            file_name = os.path.basename(tarinfo.name)
            if any(fnmatch.fnmatch(file_name, pattern) for pattern in patterns):
                return None
            return tarinfo

        return filter_func

    # Append the extension to the whole directory name. Path.with_suffix would instead replace
    # everything after the last dot, e.g. turning "dpdk-24.11" into "dpdk-24.tar".
    target_tarball_path = dir_path.with_name(f"{dir_path.name}.{compress_format.extension}")
    with tarfile.open(target_tarball_path, f"w:{compress_format.value}") as tar:
        tar.add(dir_path, arcname=dir_path.name, filter=create_filter_function(exclude))

    return target_tarball_path
324
325
326def extract_tarball(tar_path: str | Path):
327    """Extract the contents of a tarball.
328
329    The tarball will be extracted in the same path as `tar_path` parent path.
330
331    Args:
332        tar_path: The path to the tarball file to extract.
333    """
334    with tarfile.open(tar_path, "r") as tar:
335        tar.extractall(path=Path(tar_path).parent)
336
337
class PacketProtocols(Flag):
    """Flag selecting the protocol layers used for packet generation.

    Every L4 flag also carries the :attr:`IP` bit, since TCP and UDP are carried over IP.
    """

    #: IP (L3) headers only.
    IP = 1
    #: TCP over IP; includes the :attr:`IP` bit.
    TCP = IP | 2
    #: UDP over IP; includes the :attr:`IP` bit.
    UDP = IP | 4
    #: Both TCP and UDP, and therefore IP as well.
    ALL = UDP | TCP
349
350
def generate_random_packets(
    number_of: int,
    payload_size: int = 1500,
    protocols: PacketProtocols = PacketProtocols.ALL,
    ports_range: range = range(1024, 49152),
    mtu: int = 1500,
) -> list[Packet]:
    """Generate a number of random packets.

    Each packet starts with an Ethernet header, followed by an IP header when `protocols`
    contains :attr:`~PacketProtocols.IP`. When `protocols` enables any L4 protocol, one of the
    enabled ones (TCP and/or UDP) is picked at random for every packet, with source and
    destination ports drawn at random from `ports_range`.

    The payload consists of random bytes. Its size is `payload_size`, capped so that the whole
    packet, Ethernet header included, doesn't exceed `mtu`. Therefore `mtu` should match the
    MTU of the port that will send out the generated packets.

    Args:
        number_of: The number of packets to generate.
        payload_size: The packet payload size to generate, capped based on `mtu`.
        protocols: The protocols to use for the generated packets.
        ports_range: The range of L4 port numbers to use. Used only if `protocols` has L4 protocols.
        mtu: The MTU of the NIC port that will send out the generated packets.

    Raises:
        InternalError: If the `payload_size` is invalid.

    Returns:
        A list containing the randomly generated packets.
    """
    if payload_size < 0:
        raise InternalError(f"An invalid payload_size of {payload_size} was given.")

    # Candidate L4 header classes, in TCP, UDP order; one is chosen at random per packet.
    l4_layers = [
        layer
        for flag, layer in ((PacketProtocols.TCP, TCP), (PacketProtocols.UDP, UDP))
        if protocols & flag
    ]

    packets: list[Packet] = []
    for _ in range(number_of):
        packet = Ether()
        if protocols & PacketProtocols.IP:
            packet /= IP()
        if l4_layers:
            sport, dport = random.choices(ports_range, k=2)
            packet /= random.choice(l4_layers)(sport=sport, dport=dport)
        # len(packet) covers every header added so far, the Ethernet one included.
        payload_len = min(payload_size, mtu - len(packet))
        packets.append(packet / random.randbytes(payload_len))
    return packets
409
410
class MultiInheritanceBaseClass:
    """A base class for classes utilizing multiple inheritance.

    This class enables its subclasses to support both single and multiple inheritance by acting
    as a stopping point in the chain of calls to the constructors of superclasses. It's able to
    sit at the end of the Method Resolution Order (MRO), so that subclasses can call
    :meth:`super.__init__` with their own arguments without repercussion.
    """

    def __init__(self, *_args, **_kwargs) -> None:
        """Discard all arguments and call the init method of :class:`object`."""
        # The arguments belong to the subclasses' constructors. The next class in the MRO
        # (object, when this class is last) is called without any of them.
        super().__init__()
423