xref: /dpdk/dts/framework/remote_session/interactive_shell.py (revision cfa443351ef581b7189467842ca102ab710cb7d2)
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright(c) 2023 University of New Hampshire
3# Copyright(c) 2024 Arm Limited
4
5"""Common functionality for interactive shell handling.
6
7The base class, :class:`InteractiveShell`, is meant to be extended by subclasses that contain
8functionality specific to that shell type. These subclasses will often modify things like
9the prompt to expect or the arguments to pass into the application, but still utilize
10the same method for sending a command and collecting output. How this output is handled however
11is often application specific. If an application needs elevated privileges to start it is expected
12that the method for gaining those privileges is provided when initializing the class.
13
14The :option:`--timeout` command line argument and the :envvar:`DTS_TIMEOUT`
15environment variable configure the timeout of getting the output from command execution.
16"""
17
18from abc import ABC
19from pathlib import PurePath
20from typing import ClassVar
21
22from paramiko import Channel, channel  # type: ignore[import-untyped]
23
24from framework.logger import DTSLogger
25from framework.params import Params
26from framework.settings import SETTINGS
27from framework.testbed_model.node import Node
28
29
class InteractiveShell(ABC):
    """The base class for managing interactive shells.

    This class shouldn't be instantiated directly, but instead be extended. It contains
    methods for starting interactive shells as well as sending commands to these shells
    and collecting input until reaching a certain prompt. All interactive applications
    will use the same SSH connection, but each will create their own channel on that
    session.

    The SSH channel and its file objects only exist once :meth:`start_application` has run
    (which happens in ``__init__`` unless `start_on_init` is :data:`False`).
    """

    _node: Node
    _stdin: channel.ChannelStdinFile
    _stdout: channel.ChannelFile
    _ssh_channel: Channel
    _logger: DTSLogger
    _timeout: float
    _app_params: Params
    _privileged: bool
    _real_path: PurePath

    #: Prompt to expect at the end of output when sending a command.
    #: This is often overridden by subclasses.
    _default_prompt: ClassVar[str] = ""

    #: Extra characters to add to the end of every command
    #: before sending them. This is often overridden by subclasses and is
    #: most commonly an additional newline character.
    _command_extra_chars: ClassVar[str] = ""

    #: Path to the executable to start the interactive application.
    path: ClassVar[PurePath]

    def __init__(
        self,
        node: Node,
        privileged: bool = False,
        timeout: float = SETTINGS.timeout,
        start_on_init: bool = True,
        app_params: Params = Params(),
    ) -> None:
        """Store the shell configuration and optionally start the application.

        Note that the default values of `timeout` and `app_params` are evaluated once,
        when this method is defined, not on every call.

        Args:
            node: The node on which to start the interactive shell.
            privileged: Enables the shell to run as superuser.
            timeout: The timeout used for the SSH channel that is dedicated to this interactive
                shell. This timeout is for collecting output, so if reading from the buffer
                and no output is gathered within the timeout, an exception is thrown.
            start_on_init: Start interactive shell automatically after object initialisation.
            app_params: The command line parameters to be passed to the application on startup.
        """
        self._node = node
        self._logger = node._logger
        self._app_params = app_params
        self._privileged = privileged
        self._timeout = timeout
        # Ensure path is properly formatted for the host
        self._update_real_path(self.path)

        if start_on_init:
            self.start_application()

    def _setup_ssh_channel(self) -> None:
        """Open a new shell channel on the node's main session and wrap its streams.

        Sets `_ssh_channel`, `_stdin` and `_stdout`, applies the configured timeout to
        the channel and merges stderr into stdout.
        """
        self._ssh_channel = self._node.main_session.interactive_session.session.invoke_shell()
        self._stdin = self._ssh_channel.makefile_stdin("w")
        self._stdout = self._ssh_channel.makefile("r")
        self._ssh_channel.settimeout(self._timeout)
        self._ssh_channel.set_combine_stderr(True)  # combines stdout and stderr streams

    def _make_start_command(self) -> str:
        """Makes the command that starts the interactive shell.

        Returns:
            The real path of the application followed by its parameters (if any), passed
            through the main session's privileged-command helper when the shell is privileged.
        """
        start_command = f"{self._real_path} {self._app_params or ''}"
        if self._privileged:
            start_command = self._node.main_session._get_privileged_command(start_command)
        return start_command

    def start_application(self) -> None:
        """Starts a new interactive application based on the path to the app.

        This method is often overridden by subclasses as their process for
        starting may look different.
        """
        self._setup_ssh_channel()
        self.send_command(self._make_start_command())

    def send_command(
        self, command: str, prompt: str | None = None, skip_first_line: bool = False
    ) -> str:
        """Send `command` and get all output before the expected ending string.

        Lines that expect input are not included in the stdout buffer, so they cannot
        be used for expect.

        Example:
            If you were prompted to log into something with a username and password,
            you cannot expect ``username:`` because it won't yet be in the stdout buffer.
            A workaround for this could be consuming an extra newline character to force
            the current `prompt` into the stdout buffer.

        Args:
            command: The command to send.
            prompt: After sending the command, `send_command` will be expecting this string.
                If :data:`None`, will use the class's default prompt.
            skip_first_line: Skip the first line when capturing the output.

        Returns:
            All output in the buffer before expected string.
        """
        self._logger.info(f"Sending: '{command}'")
        if prompt is None:
            prompt = self._default_prompt
        self._stdin.write(f"{command}{self._command_extra_chars}\n")
        self._stdin.flush()

        # The echoed command may itself contain the prompt; such a line must not end the read.
        echoed_command = command.rstrip()
        collected: list[str] = []
        for line in self._stdout:
            if skip_first_line:
                skip_first_line = False
                continue
            if prompt in line and not line.rstrip().endswith(echoed_command):
                break
            collected.append(line)
        out = "".join(collected)
        self._logger.debug(f"Got output: {out}")
        return out

    def close(self) -> None:
        """Properly free all resources.

        Safe to call on a shell that was never started (or whose start-up failed part-way):
        only the resources that were actually created are closed.
        """
        # The channel attributes are only assigned in _setup_ssh_channel(), so they may be
        # missing when the shell was created with start_on_init=False or __init__ raised.
        stdin = getattr(self, "_stdin", None)
        if stdin is not None:
            stdin.close()
        ssh_channel = getattr(self, "_ssh_channel", None)
        if ssh_channel is not None:
            ssh_channel.close()

    def __del__(self) -> None:
        """Make sure the session is properly closed before deleting the object."""
        self.close()

    def _update_real_path(self, path: PurePath) -> None:
        """Updates the interactive shell's real path used at command line.

        Args:
            path: The path to the application, joined onto the remote path by the node's
                main session.
        """
        self._real_path = self._node.main_session.join_remote_path(path)
168