# SPDX-License-Identifier: BSD-3-Clause
# Copyright(c) 2023 PANTHEON.tech s.r.o.

"""SSH remote session."""

import socket
import traceback
from pathlib import Path, PurePath

from fabric import Connection  # type: ignore[import-untyped]
from invoke.exceptions import (  # type: ignore[import-untyped]
    CommandTimedOut,
    ThreadException,
    UnexpectedExit,
)
from paramiko.ssh_exception import (  # type: ignore[import-untyped]
    AuthenticationException,
    BadHostKeyException,
    NoValidConnectionsError,
    SSHException,
)

from framework.exception import SSHConnectionError, SSHSessionDeadError, SSHTimeoutError

from .remote_session import CommandResult, RemoteSession


class SSHSession(RemoteSession):
    """A persistent SSH connection to a remote Node.

    The connection is implemented with
    `the Fabric Python library <https://docs.fabfile.org/en/latest/>`_.

    Attributes:
        session: The underlying Fabric SSH connection.

    Raises:
        SSHConnectionError: The connection cannot be established.
    """

    session: Connection

    def _connect(self) -> None:
        """Open the Fabric connection, retrying on transient failures.

        Up to ``retry_attempts`` connection attempts are made. Errors that a retry cannot fix
        (invalid arguments, a bad host key, failed authentication) abort immediately. Network
        and generic SSH errors are logged, and the next attempt is made.

        Raises:
            SSHConnectionError: A non-retriable error occurred, or every attempt failed. In the
                latter case the distinct errors seen across all attempts are passed along.
        """
        errors: list[str] = []
        retry_attempts = 10
        # NOTE(review): a longer connect timeout is used when an explicit port is configured;
        # the rationale is not visible in this file.
        login_timeout = 20 if self.port else 10
        for retry_attempt in range(retry_attempts):
            try:
                self.session = Connection(
                    self.ip,
                    user=self.username,
                    port=self.port,
                    connect_kwargs={"password": self.password},
                    connect_timeout=login_timeout,
                )
                self.session.open()

            except (ValueError, BadHostKeyException, AuthenticationException) as e:
                # Retrying cannot fix these, so fail fast.
                self._logger.exception(e)
                raise SSHConnectionError(self.hostname) from e

            except (NoValidConnectionsError, socket.error, SSHException) as e:
                self._logger.debug(traceback.format_exc())
                self._logger.warning(e)

                # Record each distinct error once, in the order it was first seen.
                error = repr(e)
                if error not in errors:
                    errors.append(error)

                # Only announce a retry when another attempt will actually be made.
                if retry_attempt + 1 < retry_attempts:
                    self._logger.info(f"Retrying connection: retry number {retry_attempt + 1}.")

            else:
                break
        else:
            # Every attempt failed with a retriable error.
            raise SSHConnectionError(self.hostname, errors)

    def _send_command(self, command: str, timeout: float, env: dict | None) -> CommandResult:
        """Send a command and return the result of the execution.

        Args:
            command: The command to execute.
            timeout: Wait at most this long in seconds for the command execution to complete.
            env: Extra environment variables that will be used in command execution.

        Returns:
            The command's stdout, stderr and return code. A non-zero return code is reported
            in the result rather than raised, because the command runs with ``warn=True``.

        Raises:
            SSHSessionDeadError: The session died while executing the command.
            SSHTimeoutError: The command execution timed out.
        """
        try:
            # warn=True: a non-zero exit status is returned in the result instead of raising.
            # hide=True: the remote output is captured, not echoed to the local terminal.
            output = self.session.run(command, env=env, warn=True, hide=True, timeout=timeout)

        except (UnexpectedExit, ThreadException) as e:
            self._logger.exception(e)
            raise SSHSessionDeadError(self.hostname) from e

        except CommandTimedOut as e:
            self._logger.exception(e)
            raise SSHTimeoutError(command) from e

        return CommandResult(self.name, command, output.stdout, output.stderr, output.return_code)

    def is_alive(self) -> bool:
        """Overrides :meth:`~.remote_session.RemoteSession.is_alive`."""
        return self.session.is_connected

    def copy_from(self, source_file: str | PurePath, destination_dir: str | Path) -> None:
        """Overrides :meth:`~.remote_session.RemoteSession.copy_from`."""
        self.session.get(str(source_file), str(destination_dir))

    def copy_to(self, source_file: str | Path, destination_dir: str | PurePath) -> None:
        """Overrides :meth:`~.remote_session.RemoteSession.copy_to`."""
        self.session.put(str(source_file), str(destination_dir))

    def close(self) -> None:
        """Overrides :meth:`~.remote_session.RemoteSession.close`."""
        self.session.close()