xref: /dpdk/dts/framework/remote_session/ssh_session.py (revision 441c5fbf939b55f635d42ad9b7dcc3741e1c2a7c)
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright(c) 2023 PANTHEON.tech s.r.o.
3
4"""SSH remote session."""
5
6import socket
7import traceback
8from pathlib import Path, PurePath
9
10from fabric import Connection  # type: ignore[import-untyped]
11from invoke.exceptions import (  # type: ignore[import-untyped]
12    CommandTimedOut,
13    ThreadException,
14    UnexpectedExit,
15)
16from paramiko.ssh_exception import (  # type: ignore[import-untyped]
17    AuthenticationException,
18    BadHostKeyException,
19    NoValidConnectionsError,
20    SSHException,
21)
22
23from framework.exception import SSHConnectionError, SSHSessionDeadError, SSHTimeoutError
24
25from .remote_session import CommandResult, RemoteSession
26
27
class SSHSession(RemoteSession):
    """A remote session that keeps one SSH connection to a remote Node open.

    The connection is provided by
    `the Fabric Python library <https://docs.fabfile.org/en/latest/>`_.

    Attributes:
        session: The Fabric connection object backing this session.

    Raises:
        SSHConnectionError: The connection cannot be established.
    """

    session: Connection

    def _connect(self) -> None:
        """Create and open the Fabric connection, making up to ten attempts.

        The connect timeout handed to Fabric is 20 when ``self.port`` is truthy and 10 otherwise.
        A :class:`ValueError`, a bad host key or an authentication failure aborts right away.
        Other connection errors are logged, remembered (without duplicates) and the attempt
        is repeated.

        Raises:
            SSHConnectionError: A non-retryable error occurred, or every attempt failed.
        """
        # The first group is never retried; the second group triggers another attempt.
        # The fatal group must be tested first: the Paramiko errors in it would otherwise
        # be caught as plain SSHException.
        fatal_errors = (ValueError, BadHostKeyException, AuthenticationException)
        retryable_errors = (NoValidConnectionsError, socket.error, SSHException)

        max_attempts = 10
        connect_timeout = 20 if self.port else 10
        seen_errors: list[str] = []

        for attempt_number in range(1, max_attempts + 1):
            try:
                self.session = Connection(
                    self.ip,
                    user=self.username,
                    port=self.port,
                    connect_kwargs={"password": self.password},
                    connect_timeout=connect_timeout,
                )
                self.session.open()
            except fatal_errors as fatal:
                self._logger.exception(fatal)
                raise SSHConnectionError(self.hostname) from fatal
            except retryable_errors as failure:
                self._logger.debug(traceback.format_exc())
                self._logger.warning(failure)

                # Record each distinct failure only once for the final error report.
                failure_repr = repr(failure)
                if failure_repr not in seen_errors:
                    seen_errors.append(failure_repr)

                self._logger.info(f"Retrying connection: retry number {attempt_number}.")
            else:
                # Connected; nothing more to do.
                return

        # Every attempt failed with a retryable error.
        raise SSHConnectionError(self.hostname, seen_errors)

    def _send_command(self, command: str, timeout: float, env: dict | None) -> CommandResult:
        """Run `command` over the SSH connection and package its outcome.

        The command is run with ``warn=True`` and ``hide=True``.

        Args:
            command: The command to execute.
            timeout: Wait at most this long in seconds for the command execution to complete.
            env: Extra environment variables that will be used in command execution.

        Returns:
            The session name, the command, its stdout, its stderr and its return code.

        Raises:
            SSHSessionDeadError: The session died while executing the command.
            SSHTimeoutError: The command execution timed out.
        """
        try:
            result = self.session.run(command, env=env, warn=True, hide=True, timeout=timeout)
        except (UnexpectedExit, ThreadException) as dead_session:
            self._logger.exception(dead_session)
            raise SSHSessionDeadError(self.hostname) from dead_session
        except CommandTimedOut as timed_out:
            self._logger.exception(timed_out)
            raise SSHTimeoutError(command) from timed_out
        else:
            return CommandResult(
                self.name,
                command,
                result.stdout,
                result.stderr,
                result.return_code,
            )

    def is_alive(self) -> bool:
        """Overrides :meth:`~.remote_session.RemoteSession.is_alive`.

        Returns:
            The ``is_connected`` state reported by the Fabric connection.
        """
        return self.session.is_connected

    def copy_from(self, source_file: str | PurePath, destination_dir: str | Path) -> None:
        """Overrides :meth:`~.remote_session.RemoteSession.copy_from`.

        Both paths are converted to strings and passed to the Fabric connection's ``get``.
        """
        source = str(source_file)
        destination = str(destination_dir)
        self.session.get(source, destination)

    def copy_to(self, source_file: str | Path, destination_dir: str | PurePath) -> None:
        """Overrides :meth:`~.remote_session.RemoteSession.copy_to`.

        Both paths are converted to strings and passed to the Fabric connection's ``put``.
        """
        source = str(source_file)
        destination = str(destination_dir)
        self.session.put(source, destination)

    def close(self) -> None:
        """Overrides :meth:`~.remote_session.RemoteSession.close`.

        Closes the underlying Fabric connection.
        """
        self.session.close()
117