xref: /dpdk/dts/framework/remote_session/ssh_session.py (revision e9fd1ebf981f361844aea9ec94e17f4bda5e1479)
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright(c) 2023 PANTHEON.tech s.r.o.
3
4"""SSH remote session."""
5
6import socket
7import traceback
8from pathlib import PurePath
9
10from fabric import Connection  # type: ignore[import]
11from invoke.exceptions import (  # type: ignore[import]
12    CommandTimedOut,
13    ThreadException,
14    UnexpectedExit,
15)
16from paramiko.ssh_exception import (  # type: ignore[import]
17    AuthenticationException,
18    BadHostKeyException,
19    NoValidConnectionsError,
20    SSHException,
21)
22
23from framework.exception import SSHConnectionError, SSHSessionDeadError, SSHTimeoutError
24
25from .remote_session import CommandResult, RemoteSession
26
27
class SSHSession(RemoteSession):
    """A persistent SSH connection to a remote Node.

    The connection is implemented with
    `the Fabric Python library <https://docs.fabfile.org/en/latest/>`_.

    Attributes:
        session: The underlying Fabric SSH connection.

    Raises:
        SSHConnectionError: The connection cannot be established.
    """

    session: Connection

    def _connect(self) -> None:
        """Open the Fabric connection, retrying transient failures.

        Up to ten connection attempts are made. Errors that a retry cannot fix (an invalid
        argument, a bad host key or failed authentication) abort immediately, while
        network/SSH-level errors are logged and the connection is attempted again.

        Raises:
            SSHConnectionError: A non-retryable error occurred, or every attempt failed.
                In the latter case the distinct errors seen are passed to the exception.
        """
        errors: list[str] = []
        retry_attempts = 10
        # NOTE(review): a longer timeout is used when an explicit port is configured;
        # the rationale is not visible from this file.
        login_timeout = 20 if self.port else 10
        for retry_attempt in range(retry_attempts):
            try:
                self.session = Connection(
                    self.ip,
                    user=self.username,
                    port=self.port,
                    connect_kwargs={"password": self.password},
                    connect_timeout=login_timeout,
                )
                self.session.open()

            # Retrying cannot help with these, so fail fast. This clause must stay before
            # the SSHException handler below: paramiko derives both of these exceptions
            # from SSHException.
            except (ValueError, BadHostKeyException, AuthenticationException) as e:
                self._logger.exception(e)
                raise SSHConnectionError(self.hostname) from e

            except (NoValidConnectionsError, socket.error, SSHException) as e:
                self._logger.debug(traceback.format_exc())
                self._logger.warning(e)

                # Record each distinct error only once so the final report stays short.
                error = repr(e)
                if error not in errors:
                    errors.append(error)

                self._logger.info(f"Retrying connection: retry number {retry_attempt + 1}.")

            else:
                break
        else:
            # The loop finished without a successful ``break``: every attempt failed.
            raise SSHConnectionError(self.hostname, errors)

    def is_alive(self) -> bool:
        """Overrides :meth:`~.remote_session.RemoteSession.is_alive`."""
        return self.session.is_connected

    def _send_command(self, command: str, timeout: float, env: dict | None) -> CommandResult:
        """Send a command and return the result of the execution.

        Because the command is run with ``warn=True``, a non-zero exit status does not
        raise. It is reported through the returned result's return code instead. The
        ``hide=True`` setting keeps the command's output from being echoed locally.

        Args:
            command: The command to execute.
            timeout: Wait at most this long in seconds for the command execution to complete.
            env: Extra environment variables that will be used in command execution.

        Returns:
            The command's stdout, stderr and return code, tagged with this session's name.

        Raises:
            SSHSessionDeadError: The session died while executing the command.
            SSHTimeoutError: The command execution timed out.
        """
        try:
            output = self.session.run(command, env=env, warn=True, hide=True, timeout=timeout)

        except (UnexpectedExit, ThreadException) as e:
            self._logger.exception(e)
            raise SSHSessionDeadError(self.hostname) from e

        except CommandTimedOut as e:
            self._logger.exception(e)
            raise SSHTimeoutError(command) from e

        return CommandResult(self.name, command, output.stdout, output.stderr, output.return_code)

    def copy_from(
        self,
        source_file: str | PurePath,
        destination_file: str | PurePath,
    ) -> None:
        """Overrides :meth:`~.remote_session.RemoteSession.copy_from`.

        Copy ``source_file`` from the remote node to ``destination_file`` on the local
        filesystem.
        """
        # Fabric's get() copies from a remote path to a local path. For copy_from the
        # remote path is the source and the local path is the destination. Keyword
        # arguments make the direction explicit.
        self.session.get(remote=str(source_file), local=str(destination_file))

    def copy_to(
        self,
        source_file: str | PurePath,
        destination_file: str | PurePath,
    ) -> None:
        """Overrides :meth:`~.remote_session.RemoteSession.copy_to`.

        Copy the local ``source_file`` to ``destination_file`` on the remote node.
        """
        # Fabric's put() copies from a local path to a remote path.
        self.session.put(local=str(source_file), remote=str(destination_file))

    def _close(self, force: bool = False) -> None:
        """Close the underlying Fabric connection.

        Args:
            force: Unused by this implementation.
        """
        self.session.close()
124