# SPDX-License-Identifier: BSD-3-Clause
# Copyright(c) 2023 PANTHEON.tech s.r.o.

"""SSH remote session."""

import socket
import traceback
from pathlib import PurePath

from fabric import Connection  # type: ignore[import]
from invoke.exceptions import (  # type: ignore[import]
    CommandTimedOut,
    ThreadException,
    UnexpectedExit,
)
from paramiko.ssh_exception import (  # type: ignore[import]
    AuthenticationException,
    BadHostKeyException,
    NoValidConnectionsError,
    SSHException,
)

from framework.exception import SSHConnectionError, SSHSessionDeadError, SSHTimeoutError

from .remote_session import CommandResult, RemoteSession


class SSHSession(RemoteSession):
    """A persistent SSH connection to a remote Node.

    The connection is implemented with
    `the Fabric Python library <https://docs.fabfile.org/en/latest/>`_.

    Attributes:
        session: The underlying Fabric SSH connection.

    Raises:
        SSHConnectionError: The connection cannot be established.
    """

    session: Connection

    def _connect(self) -> None:
        """Open the Fabric connection, retrying on transient failures.

        Up to ten attempts are made. Errors that a retry cannot fix (invalid arguments,
        a bad host key, failed authentication) abort immediately. Connection-level errors
        are logged and retried; each distinct error is remembered so that it can be
        reported if every attempt fails.

        Raises:
            SSHConnectionError: A non-retryable error occurred or all attempts failed.
        """
        errors = []
        retry_attempts = 10
        # A longer timeout is used when a port is configured.
        login_timeout = 20 if self.port else 10
        for retry_attempt in range(retry_attempts):
            try:
                self.session = Connection(
                    self.ip,
                    user=self.username,
                    port=self.port,
                    connect_kwargs={"password": self.password},
                    connect_timeout=login_timeout,
                )
                self.session.open()

            except (ValueError, BadHostKeyException, AuthenticationException) as e:
                # Retrying won't help with these, so fail right away.
                self._logger.exception(e)
                raise SSHConnectionError(self.hostname) from e

            except (NoValidConnectionsError, socket.error, SSHException) as e:
                self._logger.debug(traceback.format_exc())
                self._logger.warning(e)

                # Record every distinct error only once.
                error = repr(e)
                if error not in errors:
                    errors.append(error)

                self._logger.info(f"Retrying connection: retry number {retry_attempt + 1}.")

            else:
                # The connection was opened successfully.
                break
        else:
            # The loop ran out of attempts without a successful connection.
            raise SSHConnectionError(self.hostname, errors)

    def is_alive(self) -> bool:
        """Overrides :meth:`~.remote_session.RemoteSession.is_alive`."""
        return self.session.is_connected

    def _send_command(self, command: str, timeout: float, env: dict | None) -> CommandResult:
        """Send a command and return the result of the execution.

        Args:
            command: The command to execute.
            timeout: Wait at most this long in seconds for the command execution to complete.
            env: Extra environment variables that will be used in command execution.

        Returns:
            The command, its standard output, standard error and return code,
            together with the name of this session.

        Raises:
            SSHSessionDeadError: The session died while executing the command.
            SSHTimeoutError: The command execution timed out.
        """
        try:
            output = self.session.run(command, env=env, warn=True, hide=True, timeout=timeout)

        except (UnexpectedExit, ThreadException) as e:
            self._logger.exception(e)
            raise SSHSessionDeadError(self.hostname) from e

        except CommandTimedOut as e:
            self._logger.exception(e)
            raise SSHTimeoutError(command) from e

        return CommandResult(self.name, command, output.stdout, output.stderr, output.return_code)

    def copy_from(
        self,
        source_file: str | PurePath,
        destination_file: str | PurePath,
    ) -> None:
        """Overrides :meth:`~.remote_session.RemoteSession.copy_from`.

        Args:
            source_file: The file to copy from the remote side of the session.
            destination_file: Where to store the file on the local filesystem.
        """
        # Fabric's Connection.get() takes the remote path first and the local path second.
        self.session.get(str(source_file), str(destination_file))

    def copy_to(
        self,
        source_file: str | PurePath,
        destination_file: str | PurePath,
    ) -> None:
        """Overrides :meth:`~.remote_session.RemoteSession.copy_to`.

        Args:
            source_file: The file on the local filesystem to copy.
            destination_file: Where to store the file on the remote side of the session.
        """
        # Fabric's Connection.put() takes the local path first and the remote path second.
        self.session.put(str(source_file), str(destination_file))

    def _close(self, force: bool = False) -> None:
        """Close the underlying Fabric connection.

        Args:
            force: Not used by this implementation.
        """
        self.session.close()