1# Copyright 2016 - The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import collections
16import os
17import re
18import shutil
19import tempfile
20import threading
21import time
22import uuid
23
24from acts import logger
25from acts.controllers.utils_lib import host_utils
26from acts.controllers.utils_lib.ssh import formatter
27from acts.libs.proc import job
28
29
class Error(Exception):
    """Raised when creating or using the ssh connection fails.

    Instances are constructed with a human-readable message and, in most
    cases, the job result of the failed ssh invocation as a second argument.
    """
32
33
class CommandError(Exception):
    """An error occurred with the command.

    Attributes:
        result: The results of the ssh command that had the error.
    """
    def __init__(self, result):
        """
        Args:
            result: The result of the ssh command that created the problem.
                    Must expose `command`, `stdout` and `stderr` attributes,
                    which are used by __str__.
        """
        # Forward the result to Exception so that `args` is populated; without
        # this, repr() loses the payload and pickling/unpickling fails because
        # the exception is re-created with no arguments.
        super(CommandError, self).__init__(result)
        self.result = result

    def __str__(self):
        return 'cmd: %s\nstdout: %s\nstderr: %s' % (
            self.result.command, self.result.stdout, self.result.stderr)
50
51
# Bookkeeping record for one forwarded port: the local port, the remote port it
# forwards to, and the background ssh process that keeps the tunnel open.
_Tunnel = collections.namedtuple('_Tunnel', 'local_port remote_port proc')
54
55
class SshConnection(object):
    """Provides a connection to a remote machine through ssh.

    Provides the ability to connect to a remote machine and execute a command
    on it. When a command is run, the connection first tries to establish a
    persistent master connection (an ssh ControlMaster socket) and reuses it.
    If the master connection cannot be set up, it falls back to a normal,
    one-off ssh connection.
    """
    @property
    def socket_path(self):
        """Returns: The os path to the master socket file."""
        return os.path.join(self._master_ssh_tempdir, 'socket')

    def __init__(self, settings):
        """
        Args:
            settings: The ssh settings to use for this connection. This class
                      reads `hostname` and `connect_timeout` from it directly
                      and passes it to the ssh formatter to build commands.
        """
        self._settings = settings
        self._formatter = formatter.SshFormatter()
        # Serializes setup of the master ssh connection.
        self._lock = threading.Lock()
        # Background process holding the master connection, or None.
        self._master_ssh_proc = None
        # Temporary directory holding the master socket, or None.
        self._master_ssh_tempdir = None
        # Currently open port forwards, as _Tunnel records.
        self._tunnels = list()

        def log_line(msg):
            return '[SshConnection | %s] %s' % (self._settings.hostname, msg)

        self.log = logger.create_logger(log_line)

    def __enter__(self):
        return self

    def __exit__(self, _, __, ___):
        self.close()

    def __del__(self):
        # If __init__ raised before every attribute was assigned there is
        # nothing to clean up, and close() would fail with AttributeError.
        # _tunnels is assigned after the master-ssh attributes, so its
        # presence implies theirs.
        if hasattr(self, '_tunnels'):
            self.close()

    def setup_master_ssh(self, timeout_seconds=5):
        """Sets up the master ssh connection.

        Sets up the initial master ssh connection if it has not already been
        started. If a previously started master connection has exited or its
        socket file is gone, it is cleaned up and a new one is started.

        Args:
            timeout_seconds: The time to wait for the master ssh connection to
            be made.

        Raises:
            Error: When setting up the master ssh connection fails.
        """
        with self._lock:
            if self._master_ssh_proc is not None:
                socket_path = self.socket_path
                if (not os.path.exists(socket_path)
                        or self._master_ssh_proc.poll() is not None):
                    self.log.debug('Master ssh connection to %s is down.',
                                   self._settings.hostname)
                    self._cleanup_master_ssh()

            if self._master_ssh_proc is None:
                # Create a shared socket in a temp location.
                self._master_ssh_tempdir = tempfile.mkdtemp(
                    prefix='ssh-master')

                # Setup flags and options for running the master ssh
                # -N: Do not execute a remote command.
                # ControlMaster: Spawn a master connection.
                # ControlPath: The master connection socket path.
                extra_flags = {'-N': None}
                extra_options = {
                    'ControlMaster': True,
                    'ControlPath': self.socket_path,
                    'BatchMode': True
                }

                # Construct the command and start it.
                master_cmd = self._formatter.format_ssh_local_command(
                    self._settings,
                    extra_flags=extra_flags,
                    extra_options=extra_options)
                self.log.info('Starting master ssh connection.')
                self._master_ssh_proc = job.run_async(master_cmd)

                end_time = time.time() + timeout_seconds

                # The master is considered up once its socket file appears.
                # The `else` clause runs only if the deadline passes without
                # a `break`.
                while time.time() < end_time:
                    if os.path.exists(self.socket_path):
                        break
                    time.sleep(.2)
                else:
                    self._cleanup_master_ssh()
                    raise Error('Master ssh connection timed out.')

    def run(self,
            command,
            timeout=60,
            ignore_status=False,
            env=None,
            io_encoding='utf-8',
            attempts=2):
        """Runs a remote command over ssh.

        Will ssh to a remote host and run a command. This method will
        block until the remote command is finished.

        Args:
            command: The command to execute over ssh. It is interpolated into
                     a shell string with '%s'.
                     NOTE(review): the original contract says a list is also
                     accepted, but a list would be rendered via its repr —
                     confirm whether any caller passes a list.
            timeout: number seconds to wait for command to finish.
            ignore_status: bool True to ignore the exit code of the remote
                           subprocess.  Note that if you do ignore status codes,
                           you should handle non-zero exit codes explicitly.
            env: dict environment variables to setup on the remote host.
            io_encoding: str unicode encoding of command output.
            attempts: Number of attempts before giving up on command failures.

        Returns:
            A job.Result containing the results of the ssh command, with the
            internal connection marker (and anything printed before it)
            removed from stdout. Returns None if attempts is 0.

        Raises:
            job.TimeoutError: When the remote command took too long to execute.
            job.Error: When the command ran but exited with a non-zero status
                       and ignore_status is False.
            Error: When the ssh connection failed (DNS failure, connection
                   timeout, permission denied, unknown host, or an
                   unrecognized failure once all attempts are used up).
        """
        if attempts == 0:
            return None
        if env is None:
            env = {}

        try:
            self.setup_master_ssh(self._settings.connect_timeout)
        except Error:
            self.log.warning('Failed to create master ssh connection, using '
                             'normal ssh connection.')

        extra_options = {'BatchMode': True}
        if self._master_ssh_proc:
            extra_options['ControlPath'] = self.socket_path

        # A unique marker is echoed before the command so a real connection
        # can be told apart from an ssh-level failure even when the remote
        # command itself fails.
        identifier = str(uuid.uuid4())
        full_command = 'echo "CONNECTED: %s"; %s' % (identifier, command)

        terminal_command = self._formatter.format_command(
            full_command, env, self._settings, extra_options=extra_options)

        dns_retry_count = 2
        while True:
            result = job.run(terminal_command,
                             ignore_status=True,
                             timeout=timeout,
                             io_encoding=io_encoding)
            output = result.stdout

            # Check for a connected message to prevent false negatives.
            valid_connection = re.search('^CONNECTED: %s' % identifier,
                                         output,
                                         flags=re.MULTILINE)
            if valid_connection:
                # Drop everything up to and including the marker line. Output
                # printed before the marker (e.g. by remote shell startup
                # files) is not part of the command's output.
                newline_index = output.find('\n', valid_connection.end())
                if newline_index == -1:
                    line_index = len(output)
                else:
                    line_index = newline_index + 1
                real_output = output[line_index:].encode(io_encoding)

                result = job.Result(command=result.command,
                                    stdout=real_output,
                                    stderr=result._raw_stderr,
                                    exit_status=result.exit_status,
                                    duration=result.duration,
                                    did_timeout=result.did_timeout,
                                    encoding=io_encoding)
                if result.exit_status and not ignore_status:
                    raise job.Error(result)
                return result

            error_string = result.stderr

            had_dns_failure = (result.exit_status == 255 and re.search(
                r'^ssh: .*: Name or service not known',
                error_string,
                flags=re.MULTILINE))
            if had_dns_failure:
                dns_retry_count -= 1
                if not dns_retry_count:
                    raise Error('DNS failed to find host.', result)
                self.log.debug('Failed to connect to host, retrying...')
            else:
                break

        had_timeout = re.search(
            r'^ssh: connect to host .* port .*: '
            r'Connection timed out\r$',
            error_string,
            flags=re.MULTILINE)
        if had_timeout:
            raise Error('Ssh timed out.', result)

        permission_denied = 'Permission denied' in error_string
        if permission_denied:
            raise Error('Permission denied.', result)

        unknown_host = re.search(
            r'ssh: Could not resolve hostname .*: '
            r'Name or service not known',
            error_string,
            flags=re.MULTILINE)
        if unknown_host:
            raise Error('Unknown host.', result)

        self.log.error('An unknown error has occurred. Job result: %s',
                       result)
        ping_output = job.run('ping %s -c 3 -w 1' % self._settings.hostname,
                              ignore_status=True)
        self.log.error('Ping result: %s', ping_output)
        if attempts > 1:
            # Tear down a possibly-broken master connection and retry; the
            # retry's result (or exception) is what the caller receives.
            self._cleanup_master_ssh()
            return self.run(command, timeout, ignore_status, env, io_encoding,
                            attempts - 1)
        raise Error('The job failed for unknown reasons.', result)

    def run_async(self, command, env=None):
        """Starts up a background command over ssh.

        Will ssh to a remote host and startup a command. This method will
        block until there is confirmation that the remote command has started.
        The remote command's stdin/stdout/stderr are redirected to /dev/null.

        Args:
            command: The command to execute over ssh, as a string.
            env: A dictionary of environment variables to setup on the remote
                 host.

        Returns:
            The result of the command to launch the background job. Its stdout
            is the remote PID of the background process (from `echo -n $!`).

        Raises:
            job.TimeoutError: When launching the remote command took too long.
            Error: When the ssh connection failed.
        """
        command = '(%s) < /dev/null > /dev/null 2>&1 & echo -n $!' % command
        result = self.run(command, env=env)
        return result

    def close(self):
        """Clean up open connections to remote host."""
        self._cleanup_master_ssh()
        while self._tunnels:
            self.close_ssh_tunnel(self._tunnels[0].local_port)

    def _cleanup_master_ssh(self):
        """
        Release all resources (process, temporary directory) used by an active
        master SSH connection.
        """
        # If a master SSH connection is running, kill it.
        if self._master_ssh_proc is not None:
            self.log.debug('Nuking master_ssh_job.')
            self._master_ssh_proc.kill()
            self._master_ssh_proc.wait()
            self._master_ssh_proc = None

        # Remove the temporary directory for the master SSH socket. Errors are
        # ignored: the directory may already be gone, and raising here would
        # leave _master_ssh_tempdir set so every later cleanup failed too.
        if self._master_ssh_tempdir is not None:
            self.log.debug('Cleaning master_ssh_tempdir.')
            shutil.rmtree(self._master_ssh_tempdir, ignore_errors=True)
            self._master_ssh_tempdir = None

    def create_ssh_tunnel(self, port, local_port=None):
        """Create an ssh tunnel from local_port to port.

        This securely forwards traffic from local_port on this machine to the
        remote SSH host at port.

        Args:
            port: remote port on the host.
            local_port: local forwarding port, or None to pick an available
                        port.

        Returns:
            The int local port that forwards to `port`. If local_port is given
            and a tunnel to `port` already exists, that tunnel's local port is
            returned and no new tunnel is created.
        """
        # NOTE(review): reuse of an existing tunnel only happens when
        # local_port is given; with local_port=None a new tunnel is always
        # created — confirm this asymmetry is intended.
        if not local_port:
            local_port = host_utils.get_available_host_port()
        else:
            for tunnel in self._tunnels:
                if tunnel.remote_port == port:
                    return tunnel.local_port

        extra_flags = {
            '-n': None,  # Read from /dev/null for stdin
            '-N': None,  # Do not execute a remote command
            '-q': None,  # Suppress warnings and diagnostic commands
            '-L': '%d:localhost:%d' % (local_port, port),
        }
        extra_options = dict()
        if self._master_ssh_proc:
            extra_options['ControlPath'] = self.socket_path
        tunnel_cmd = self._formatter.format_ssh_local_command(
            self._settings,
            extra_flags=extra_flags,
            extra_options=extra_options)
        self.log.debug('Full tunnel command: %s', tunnel_cmd)
        # Exec the ssh process directly so that when we deliver signals, we
        # deliver them straight to the child process.
        tunnel_proc = job.run_async(tunnel_cmd)
        self.log.debug('Started ssh tunnel, local = %d remote = %d, pid = %d',
                       local_port, port, tunnel_proc.pid)
        self._tunnels.append(_Tunnel(local_port, port, tunnel_proc))
        return local_port

    def close_ssh_tunnel(self, local_port):
        """Close a previously created ssh tunnel of a TCP port.

        Args:
            local_port: int port on localhost previously forwarded to the remote
                        host.

        Returns:
            integer port number this port was forwarded to on the remote host or
            None if no tunnel was found.
        """
        idx = None
        for i, tunnel in enumerate(self._tunnels):
            if tunnel.local_port == local_port:
                idx = i
                break
        if idx is not None:
            tunnel = self._tunnels.pop(idx)
            tunnel.proc.kill()
            tunnel.proc.wait()
            return tunnel.remote_port
        return None

    def send_file(self, local_path, remote_path, ignore_status=False):
        """Send a file from the local host to the remote host.

        Args:
            local_path: string path of file to send on local host.
            remote_path: string path to copy file to on remote host.
            ignore_status: Whether or not to ignore the command's exit_status.
        """
        # TODO: This may belong somewhere else: b/32572515
        user_host = self._formatter.format_host_name(self._settings)
        job.run('scp %s %s:%s' % (local_path, user_host, remote_path),
                ignore_status=ignore_status)

    def pull_file(self, local_path, remote_path, ignore_status=False):
        """Copy a file from the remote host to the local host.

        Args:
            local_path: string path of file to recv on local host
            remote_path: string path to copy file from on remote host.
            ignore_status: Whether or not to ignore the command's exit_status.
        """
        user_host = self._formatter.format_host_name(self._settings)
        job.run('scp %s:%s %s' % (user_host, remote_path, local_path),
                ignore_status=ignore_status)

    def find_free_port(self, interface_name='localhost'):
        """Find a unused port on the remote host.

        Note that this method is inherently racy, since it is impossible
        to promise that the remote port will remain free. It requires a
        `python` executable on the remote host.

        Args:
            interface_name: string name of interface to check whether a
                            port is used against.

        Returns:
            integer port number on remote interface that was free.
        """
        # TODO: This may belong somewhere else: b/3257251
        free_port_cmd = (
            'python -c "import socket; s=socket.socket(); '
            's.bind((\'%s\', 0)); print(s.getsockname()[1]); s.close()"'
        ) % interface_name
        port = int(self.run(free_port_cmd).stdout)
        # Yield to the os to ensure the port gets cleaned up.
        time.sleep(0.001)
        return port
442