# Copyright 2016 - The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import os
import re
import shutil
import tempfile
import threading
import time
import uuid

from acts import logger
from acts.controllers.utils_lib import host_utils
from acts.controllers.utils_lib.ssh import formatter
from acts.libs.proc import job


class Error(Exception):
    """An error occurred during an ssh operation."""


class CommandError(Exception):
    """An error occurred with the command.

    Attributes:
        result: The results of the ssh command that had the error.
    """
    def __init__(self, result):
        """
        Args:
            result: The result of the ssh command that created the problem.
        """
        # Forward to the base class so `args` is populated; without it the
        # exception cannot be pickled/copied (reconstruction calls
        # CommandError() with no arguments).
        super(CommandError, self).__init__(result)
        self.result = result

    def __str__(self):
        return 'cmd: %s\nstdout: %s\nstderr: %s' % (
            self.result.command, self.result.stdout, self.result.stderr)


# Bookkeeping record for an open port-forwarding tunnel.
_Tunnel = collections.namedtuple('_Tunnel',
                                 ['local_port', 'remote_port', 'proc'])


class SshConnection(object):
    """Provides a connection to a remote machine through ssh.

    Provides the ability to connect to a remote machine and execute a command
    on it. The connection will try to establish a persistent (master)
    connection when a command is run. If the persistent connection fails it
    will attempt to connect normally.
    """
    @property
    def socket_path(self):
        """Returns: The os path to the master socket file."""
        return os.path.join(self._master_ssh_tempdir, 'socket')

    def __init__(self, settings):
        """
        Args:
            settings: The ssh settings to use for this connection.
        """
        self._settings = settings
        # Formats ssh command lines for use with the background job runner.
        self._formatter = formatter.SshFormatter()
        # Guards creation/teardown of the master connection in
        # setup_master_ssh.
        self._lock = threading.Lock()
        self._master_ssh_proc = None
        self._master_ssh_tempdir = None
        self._tunnels = list()

        def log_line(msg):
            return '[SshConnection | %s] %s' % (self._settings.hostname, msg)

        self.log = logger.create_logger(log_line)

    def __enter__(self):
        return self

    def __exit__(self, _, __, ___):
        self.close()

    def __del__(self):
        self.close()

    def setup_master_ssh(self, timeout_seconds=5):
        """Sets up the master ssh connection.

        Sets up the initial master ssh connection if it has not already been
        started. An existing master whose socket file has disappeared or
        whose process has exited is torn down and restarted.

        Args:
            timeout_seconds: The time to wait for the master ssh connection to
                be made.

        Raises:
            Error: When setting up the master ssh connection fails.
        """
        with self._lock:
            if self._master_ssh_proc is not None:
                socket_path = self.socket_path
                if (not os.path.exists(socket_path)
                        or self._master_ssh_proc.poll() is not None):
                    self.log.debug('Master ssh connection to %s is down.',
                                   self._settings.hostname)
                    self._cleanup_master_ssh()

            if self._master_ssh_proc is None:
                # Create a shared socket in a temp location.
                self._master_ssh_tempdir = tempfile.mkdtemp(
                    prefix='ssh-master')

                # Setup flags and options for running the master ssh
                # -N: Do not execute a remote command.
                # ControlMaster: Spawn a master connection.
                # ControlPath: The master connection socket path.
                extra_flags = {'-N': None}
                extra_options = {
                    'ControlMaster': True,
                    'ControlPath': self.socket_path,
                    'BatchMode': True
                }

                # Construct the command and start it.
                master_cmd = self._formatter.format_ssh_local_command(
                    self._settings,
                    extra_flags=extra_flags,
                    extra_options=extra_options)
                self.log.info('Starting master ssh connection.')
                self._master_ssh_proc = job.run_async(master_cmd)

                end_time = time.time() + timeout_seconds

                # The socket file appearing is the signal that the master
                # connection is ready to be shared.
                while time.time() < end_time:
                    if os.path.exists(self.socket_path):
                        break
                    time.sleep(.2)
                else:
                    self._cleanup_master_ssh()
                    raise Error('Master ssh connection timed out.')

    def run(self,
            command,
            timeout=60,
            ignore_status=False,
            env=None,
            io_encoding='utf-8',
            attempts=2):
        """Runs a remote command over ssh.

        Will ssh to a remote host and run a command. This method will
        block until the remote command is finished.

        Args:
            command: The command to execute over ssh. Can be either a string
                or a list.
            timeout: number seconds to wait for command to finish.
            ignore_status: bool True to ignore the exit code of the remote
                subprocess. Note that if you do ignore status codes,
                you should handle non-zero exit codes explicitly.
            env: dict environment variables to setup on the remote host.
            io_encoding: str unicode encoding of command output.
            attempts: Number of attempts before giving up on command failures.

        Returns:
            A job.Result containing the results of the ssh command, or None
            if attempts is 0.

        Raises:
            job.TimeoutError: When the remote command took too long to
                execute.
            Error: When the ssh connection failed to be created.
            job.Error: Ssh worked, but the command exited with a non-zero
                status and ignore_status is False.
        """
        if attempts == 0:
            return None
        if env is None:
            env = {}

        try:
            self.setup_master_ssh(self._settings.connect_timeout)
        except Error:
            self.log.warning('Failed to create master ssh connection, using '
                             'normal ssh connection.')

        extra_options = {'BatchMode': True}
        if self._master_ssh_proc:
            extra_options['ControlPath'] = self.socket_path

        # A unique marker echoed before the command lets us tell "ssh
        # connected and the command ran" apart from "ssh itself failed".
        identifier = str(uuid.uuid4())
        full_command = 'echo "CONNECTED: %s"; %s' % (identifier, command)

        terminal_command = self._formatter.format_command(
            full_command, env, self._settings, extra_options=extra_options)

        dns_retry_count = 2
        while True:
            result = job.run(terminal_command,
                             ignore_status=True,
                             timeout=timeout,
                             io_encoding=io_encoding)
            output = result.stdout

            # Check for a connected message to prevent false negatives.
            valid_connection = re.search('^CONNECTED: %s' % identifier,
                                         output,
                                         flags=re.MULTILINE)
            if valid_connection:
                # Remove the first line that contains the connect message.
                line_index = output.find('\n') + 1
                if line_index == 0:
                    line_index = len(output)
                real_output = output[line_index:].encode(io_encoding)

                result = job.Result(command=result.command,
                                    stdout=real_output,
                                    stderr=result._raw_stderr,
                                    exit_status=result.exit_status,
                                    duration=result.duration,
                                    did_timeout=result.did_timeout,
                                    encoding=io_encoding)
                if result.exit_status and not ignore_status:
                    raise job.Error(result)
                return result

            error_string = result.stderr

            had_dns_failure = (result.exit_status == 255 and re.search(
                r'^ssh: .*: Name or service not known',
                error_string,
                flags=re.MULTILINE))
            if had_dns_failure:
                dns_retry_count -= 1
                if not dns_retry_count:
                    raise Error('DNS failed to find host.', result)
                self.log.debug('Failed to connect to host, retrying...')
            else:
                break

        had_timeout = re.search(
            r'^ssh: connect to host .* port .*: '
            r'Connection timed out\r$',
            error_string,
            flags=re.MULTILINE)
        if had_timeout:
            raise Error('Ssh timed out.', result)

        permission_denied = 'Permission denied' in error_string
        if permission_denied:
            raise Error('Permission denied.', result)

        unknown_host = re.search(
            r'ssh: Could not resolve hostname .*: '
            r'Name or service not known',
            error_string,
            flags=re.MULTILINE)
        if unknown_host:
            raise Error('Unknown host.', result)

        self.log.error('An unknown error has occurred. Job result: %s',
                       result)
        ping_output = job.run('ping %s -c 3 -w 1' % self._settings.hostname,
                              ignore_status=True)
        self.log.error('Ping result: %s', ping_output)
        if attempts > 1:
            # Drop a possibly-stale master connection and retry. The retry's
            # result must be returned; previously it was discarded and the
            # error below was raised even when the retry succeeded.
            self._cleanup_master_ssh()
            return self.run(command, timeout, ignore_status, env, io_encoding,
                            attempts - 1)
        raise Error('The job failed for unknown reasons.', result)

    def run_async(self, command, env=None):
        """Starts up a background command over ssh.

        Will ssh to a remote host and startup a command. This method will
        block until there is confirmation that the remote command has started.

        Args:
            command: The command to execute over ssh. Can be either a string
                or a list.
            env: A dictionary of environment variables to setup on the remote
                host.

        Returns:
            The result of the command to launch the background job; its
            stdout holds the remote shell's `$!` (the background job's pid).

        Raises:
            job.TimeoutError: When the launching command took too long.
            Error: When the ssh connection failed to be established.
            job.Error: When the launching command exited with a non-zero
                status.
        """
        command = '(%s) < /dev/null > /dev/null 2>&1 & echo -n $!' % command
        result = self.run(command, env=env)
        return result

    def close(self):
        """Clean up open connections to remote host."""
        self._cleanup_master_ssh()
        while self._tunnels:
            self.close_ssh_tunnel(self._tunnels[0].local_port)

    def _cleanup_master_ssh(self):
        """
        Release all resources (process, temporary directory) used by an active
        master SSH connection.
        """
        # If a master SSH connection is running, kill it.
        if self._master_ssh_proc is not None:
            self.log.debug('Nuking master_ssh_job.')
            self._master_ssh_proc.kill()
            self._master_ssh_proc.wait()
            self._master_ssh_proc = None

        # Remove the temporary directory for the master SSH socket.
        if self._master_ssh_tempdir is not None:
            self.log.debug('Cleaning master_ssh_tempdir.')
            shutil.rmtree(self._master_ssh_tempdir)
            self._master_ssh_tempdir = None

    def create_ssh_tunnel(self, port, local_port=None):
        """Create an ssh tunnel from local_port to port.

        This securely forwards traffic from local_port on this machine to the
        remote SSH host at port.

        Args:
            port: remote port on the host.
            local_port: local forwarding port, or None to pick an available
                port.

        Returns:
            the local port number the tunnel listens on. If local_port was
            given and a tunnel to the same remote port already exists, that
            existing tunnel's local port is returned instead.
        """
        # NOTE(review): the reuse check below only runs when the caller
        # supplies local_port; with local_port=None a duplicate tunnel is
        # always created. This looks inverted but is preserved as-is —
        # confirm intended behavior before changing.
        if not local_port:
            local_port = host_utils.get_available_host_port()
        else:
            for tunnel in self._tunnels:
                if tunnel.remote_port == port:
                    return tunnel.local_port

        extra_flags = {
            '-n': None,  # Read from /dev/null for stdin
            '-N': None,  # Do not execute a remote command
            '-q': None,  # Suppress warnings and diagnostic commands
            '-L': '%d:localhost:%d' % (local_port, port),
        }
        extra_options = dict()
        if self._master_ssh_proc:
            extra_options['ControlPath'] = self.socket_path
        tunnel_cmd = self._formatter.format_ssh_local_command(
            self._settings,
            extra_flags=extra_flags,
            extra_options=extra_options)
        self.log.debug('Full tunnel command: %s', tunnel_cmd)
        # Exec the ssh process directly so that when we deliver signals, we
        # deliver them straight to the child process.
        tunnel_proc = job.run_async(tunnel_cmd)
        self.log.debug('Started ssh tunnel, local = %d remote = %d, pid = %d',
                       local_port, port, tunnel_proc.pid)
        self._tunnels.append(_Tunnel(local_port, port, tunnel_proc))
        return local_port

    def close_ssh_tunnel(self, local_port):
        """Close a previously created ssh tunnel of a TCP port.

        Args:
            local_port: int port on localhost previously forwarded to the remote
                host.

        Returns:
            integer port number this port was forwarded to on the remote host or
            None if no tunnel was found.
        """
        idx = None
        for i, tunnel in enumerate(self._tunnels):
            if tunnel.local_port == local_port:
                idx = i
                break
        if idx is not None:
            tunnel = self._tunnels.pop(idx)
            tunnel.proc.kill()
            tunnel.proc.wait()
            return tunnel.remote_port
        return None

    def send_file(self, local_path, remote_path, ignore_status=False):
        """Send a file from the local host to the remote host.

        Args:
            local_path: string path of file to send on local host.
            remote_path: string path to copy file to on remote host.
            ignore_status: Whether or not to ignore the command's exit_status.
        """
        # TODO: This may belong somewhere else: b/32572515
        # NOTE(review): paths are interpolated unquoted into the scp command
        # line; paths containing spaces or shell metacharacters will break.
        user_host = self._formatter.format_host_name(self._settings)
        job.run('scp %s %s:%s' % (local_path, user_host, remote_path),
                ignore_status=ignore_status)

    def pull_file(self, local_path, remote_path, ignore_status=False):
        """Copy a file from the remote host to the local host.

        Args:
            local_path: string path to copy the file to on the local host.
            remote_path: string path of the file to copy from the remote host.
            ignore_status: Whether or not to ignore the command's exit_status.
        """
        # NOTE(review): paths are interpolated unquoted, as in send_file.
        user_host = self._formatter.format_host_name(self._settings)
        job.run('scp %s:%s %s' % (user_host, remote_path, local_path),
                ignore_status=ignore_status)

    def find_free_port(self, interface_name='localhost'):
        """Find an unused port on the remote host.

        Note that this method is inherently racy, since it is impossible
        to promise that the remote port will remain free.

        Args:
            interface_name: string name of interface to check whether a
                port is used against.

        Returns:
            integer port number on remote interface that was free.
        """
        # TODO: This may belong somewhere else: b/3257251
        # Binds to port 0 on the remote side so the OS picks a free port,
        # prints it, and releases it.
        free_port_cmd = (
            'python -c "import socket; s=socket.socket(); '
            's.bind((\'%s\', 0)); print(s.getsockname()[1]); s.close()"'
        ) % interface_name
        port = int(self.run(free_port_cmd).stdout)
        # Yield to the os to ensure the port gets cleaned up.
        time.sleep(0.001)
        return port