1# Copyright 2019 - The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Ssh Utilities."""
15from __future__ import print_function
16import logging
17
18import subprocess
19import sys
20import threading
21
22from acloud import errors
23from acloud.internal import constants
24from acloud.internal.lib import utils
25
logger = logging.getLogger(__name__)

# Options shared by both ssh and scp: authenticate with the given private key,
# run quietly, and skip host-key verification (known hosts go to /dev/null).
_SSH_CMD = ("-i %(rsa_key_file)s -q "
            "-o UserKnownHostsFile=/dev/null "
            "-o StrictHostKeyChecking=no")
# Login user and target address; Ssh.GetBaseCmd appends this only for ssh.
_SSH_IDENTITY = "-l %(login_user)s %(ip_addr)s"
# Default number of retries for ShellCmdWithRetry, Ssh.Run and Ssh.WaitForSsh.
_SSH_CMD_MAX_RETRY = 5
# sleep_multiplier handed to utils.RetryExceptionType by ShellCmdWithRetry.
_SSH_CMD_RETRY_SLEEP = 3
# Per-attempt timeout (seconds) that Ssh.WaitForSsh gives CheckSshConnection.
_CONNECTION_TIMEOUT = 10
34
35
def _SshCallWait(cmd, timeout=None):
    """Runs a single SSH command and waits for it to exit, discarding output.

    - SSH returns code 0 for "Successful execution".
    - Uses wait() until the process is complete; the command's output is
      never read, so it is sent to os.devnull rather than an unread pipe.

    Args:
        cmd: String of the full SSH command to run, including the SSH binary
             and its arguments.
        timeout: Optional number of seconds after which the process is killed.

    Returns:
        The process exit status; 0 indicates that it ran successfully. A
        negative value means the process was terminated by a signal (e.g.
        killed because the timeout expired).
    """
    import os  # Only needed here, for os.devnull.

    # Use "exec" to let cmd inherit the shell process, instead of having the
    # shell launch a child process which does not get killed on timeout.
    cmd = "exec " + cmd
    logger.info("Running command \"%s\"", cmd)
    with open(os.devnull, "wb") as devnull:
        # Discard output instead of using stdout=PIPE: wait() never reads the
        # pipe, so a chatty command could fill its buffer and block forever.
        process = subprocess.Popen(cmd, shell=True, stdin=None,
                                   stdout=devnull, stderr=subprocess.STDOUT)
        timer = None
        if timeout:
            # TODO: if process is killed, out error message to log.
            timer = threading.Timer(timeout, process.kill)
            timer.start()
        process.wait()
        if timer:
            timer.cancel()
    return process.returncode
61
62
def _SshCall(cmd, timeout=None):
    """Runs a single SSH command.

    - SSH returns code 0 for "Successful execution".
    - Use communicate() until the process and the child thread are complete.

    Args:
        cmd: String of the full SSH command to run, including the SSH binary
             and its arguments.
        timeout: Optional number of seconds after which the process is killed.

    Returns:
        The process exit status; 0 indicates that it ran successfully.
    """
    # Use "exec" to let cmd inherit the shell process, instead of having the
    # shell launch a child process which does not get killed. Without it,
    # process.kill() may only hit the wrapping shell while the real command
    # keeps the stdout pipe open, so communicate() would block past timeout.
    cmd = "exec " + cmd
    logger.info("Running command \"%s\"", cmd)
    process = subprocess.Popen(cmd, shell=True, stdin=None,
                               stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    timer = None
    if timeout:
        # TODO: if process is killed, out error message to log.
        timer = threading.Timer(timeout, process.kill)
        timer.start()
    # Drain output (it is discarded) so the child never blocks on a full pipe.
    process.communicate()
    if timer:
        timer.cancel()
    return process.returncode
88
89
def _SshLogOutput(cmd, timeout=None, show_output=False):
    """Runs a single SSH command while logging its output and processes its return code.

    The combined stdout/stderr is collected and, once the command exits, either
    printed to stderr (when show_output is set or the command failed) or logged
    at debug level. SSH returns error code 255 for "failed to connect", so this
    is interpreted as a failure in SSH rather than a failure on the target
    device and this is converted to a different exception type.

    Args:
        cmd: String of the full SSH command to run, including the SSH binary and its arguments.
        timeout: Optional integer, number of seconds after which the process is killed.
        show_output: Boolean, True to show command output in screen.

    Raises:
        errors.DeviceConnectionError: Failed to connect to the GCE instance.
        subprocess.CalledProcessError: The process exited with an error on the instance.
    """
    # Use "exec" to let cmd to inherit the shell process, instead of having the
    # shell launch a child process which does not get killed.
    cmd = "exec " + cmd
    logger.info("Running command \"%s\"", cmd)
    process = subprocess.Popen(cmd, shell=True, stdin=None,
                               stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    timer = None
    if timeout:
        # TODO: if process is killed, out error message to log.
        timer = threading.Timer(timeout, process.kill)
        timer.start()
    stdout, _ = process.communicate()
    if timer:
        timer.cancel()
    if stdout:
        # communicate() returns bytes on Python 3; decode so the actual text,
        # not its b'...' repr with escaped newlines, is shown/logged.
        if isinstance(stdout, bytes):
            stdout = stdout.decode("utf-8", "replace")
        output = stdout.strip()
        if show_output or process.returncode != 0:
            print(output, file=sys.stderr)
        else:
            # fetch_cvd and launch_cvd can be noisy, so left at debug
            logger.debug(output)
    if process.returncode == 255:
        raise errors.DeviceConnectionError(
            "Failed to send command to instance (%s)" % cmd)
    if process.returncode != 0:
        raise subprocess.CalledProcessError(process.returncode, cmd)
131
132
def ShellCmdWithRetry(cmd, timeout=None, show_output=False,
                      retry=_SSH_CMD_MAX_RETRY):
    """Runs a local ssh/scp command line, retrying it when it fails.

    Each attempt goes through _SshLogOutput. An attempt that raises
    errors.DeviceConnectionError (ssh exit code 255) or
    subprocess.CalledProcessError (any other non-zero exit) is retried by
    utils.RetryExceptionType, up to `retry` times. Because an unstable network
    may still be unstable a moment later, the wait between attempts is scaled
    by _SSH_CMD_RETRY_SLEEP and utils.DEFAULT_RETRY_BACKOFF_FACTOR; the exact
    schedule is defined by utils.RetryExceptionType.

    Args:
        cmd: String of the full SSH command to run, including the SSH binary and its arguments.
        timeout: Optional integer, number of seconds to give each attempt.
        show_output: Boolean, True to show command output in screen.
        retry: Integer, the retry times.

    Raises:
        errors.DeviceConnectionError / subprocess.CalledProcessError:
            presumably re-raised by utils.RetryExceptionType once all retries
            fail (confirm in utils).
    """
    retriable = (errors.DeviceConnectionError, subprocess.CalledProcessError)
    ssh_kwargs = dict(cmd=cmd, timeout=timeout, show_output=show_output)
    utils.RetryExceptionType(
        exception_types=retriable,
        max_retries=retry,
        functor=_SshLogOutput,
        sleep_multiplier=_SSH_CMD_RETRY_SLEEP,
        retry_backoff_factor=utils.DEFAULT_RETRY_BACKOFF_FACTOR,
        **ssh_kwargs)
161
162
class IP(object):
    """Holds the external and internal IP addresses of an instance."""

    def __init__(self, external=None, internal=None, ip=None):
        """Init for IP.

        Args:
            external: String, external ip.
            internal: String, internal ip.
            ip: String, fallback ip used for external and/or internal when
                that value is not given (falsy).
        """
        self.external = external if external else ip
        self.internal = internal if internal else ip
175
176
class Ssh(object):
    """A class that control the remote instance via the IP address.

    Attributes:
        _ip: String, the address commands are sent to: ip.internal when
             report_internal_ip is True, otherwise ip.external.
        _user: String of user login into the instance.
        _ssh_private_key_path: Path to the private key file.
        _extra_args_ssh_tunnel: String, extra args for ssh or scp.
    """
    def __init__(self, ip, user, ssh_private_key_path,
                 extra_args_ssh_tunnel=None, report_internal_ip=False):
        """Init for Ssh.

        Args:
            ip: An IP object providing .external and .internal addresses.
            user: String, login user on the instance.
            ssh_private_key_path: String, private key file passed via -i.
            extra_args_ssh_tunnel: String, extra args appended to ssh/scp.
            report_internal_ip: Boolean, True to connect via the internal IP.
        """
        self._ip = ip.internal if report_internal_ip else ip.external
        self._user = user
        self._ssh_private_key_path = ssh_private_key_path
        self._extra_args_ssh_tunnel = extra_args_ssh_tunnel

    def Run(self, target_command, timeout=None, show_output=False,
            retry=_SSH_CMD_MAX_RETRY):
        """Run a shell command over SSH on a remote instance.

        The executed command line is the ssh base command followed by
        target_command, e.g.:
            ssh -i ~/private_key_path ... -l user 1.1.1.1 remote command

        Args:
            target_command: String, text of command to run on the remote instance.
            timeout: Integer, the maximum time to wait for the command to respond.
            show_output: Boolean, True to show command output in screen.
            retry: Integer, the retry times.
        """
        ShellCmdWithRetry(self.GetBaseCmd(constants.SSH_BIN) + " " + target_command,
                          timeout,
                          show_output,
                          retry)

    def GetBaseCmd(self, execute_bin):
        """Get a base command over SSH on a remote instance.

        Example:
            execute bin is ssh:
                ssh -i ~/private_key_path $extra_args -l user 1.1.1.1
            execute bin is scp:
                scp -i ~/private_key_path $extra_args

        Args:
            execute_bin: String, execute type, e.g. ssh or scp.

        Returns:
            Strings of base connection command.

        Raises:
            errors.UnknownType: Don't support the execute bin.
        """
        base_cmd = [utils.FindExecutable(execute_bin)]
        base_cmd.append(_SSH_CMD % {"rsa_key_file": self._ssh_private_key_path})
        if self._extra_args_ssh_tunnel:
            base_cmd.append(self._extra_args_ssh_tunnel)

        if execute_bin == constants.SSH_BIN:
            base_cmd.append(_SSH_IDENTITY %
                            {"login_user":self._user, "ip_addr":self._ip})
            return " ".join(base_cmd)
        if execute_bin == constants.SCP_BIN:
            # scp takes user@host inside each path argument instead.
            return " ".join(base_cmd)

        raise errors.UnknownType("Don't support the execute bin %s." % execute_bin)

    def GetCmdOutput(self, cmd):
        """Runs a single SSH command and get its output.

        Args:
            cmd: String, text of command to run on the remote instance.

        Returns:
            String of the combined stdout/stderr of the command. Bytes that are
            not valid UTF-8 are replaced rather than raising an error.
        """
        ssh_cmd = "exec " + self.GetBaseCmd(constants.SSH_BIN) + " " + cmd
        logger.info("Running command \"%s\"", ssh_cmd)
        process = subprocess.Popen(ssh_cmd, shell=True, stdin=None,
                                   stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        stdout, _ = process.communicate()
        # Remote output is arbitrary; don't let a stray non-UTF-8 byte raise
        # UnicodeDecodeError.
        return stdout.decode("utf-8", "replace")

    def CheckSshConnection(self, timeout):
        """Run remote 'uptime' ssh command to check ssh connection.

        Args:
            timeout: Integer, the maximum time to wait for the command to respond.

        Raises:
            errors.DeviceConnectionError: Ssh isn't ready in the remote instance.
        """
        remote_cmd = [self.GetBaseCmd(constants.SSH_BIN)]
        remote_cmd.append("uptime")

        if _SshCallWait(" ".join(remote_cmd), timeout) == 0:
            return
        raise errors.DeviceConnectionError(
            "Ssh isn't ready in the remote instance.")

    @utils.TimeExecute(function_description="Waiting for SSH server")
    def WaitForSsh(self, timeout=None, max_retry=_SSH_CMD_MAX_RETRY):
        """Wait until the remote instance is ready to accept commands over SSH.

        Args:
            timeout: Integer, the overall time budget in seconds used to derive
                     the sleep between retries (defaults to
                     constants.DEFAULT_SSH_TIMEOUT).
            max_retry: Integer, the maximum number of retry.

        Raises:
            errors.DeviceConnectionError: Ssh isn't ready in the remote instance.
        """
        ssh_timeout = timeout or constants.DEFAULT_SSH_TIMEOUT
        # Presumably the retry sleeps grow as multiplier * 1, 2, ..., max_retry,
        # so dividing by their sum spreads ssh_timeout over the retries
        # (confirm against utils.RetryExceptionType). The sum is 0 when
        # max_retry <= 0; fall back to 1 to avoid ZeroDivisionError.
        total_weight = sum(range(max_retry + 1)) or 1
        sleep_multiplier = ssh_timeout / total_weight
        logger.debug("Retry with interval time: %s secs", str(sleep_multiplier))
        utils.RetryExceptionType(
            exception_types=errors.DeviceConnectionError,
            max_retries=max_retry,
            functor=self.CheckSshConnection,
            sleep_multiplier=sleep_multiplier,
            retry_backoff_factor=utils.DEFAULT_RETRY_BACKOFF_FACTOR,
            timeout=_CONNECTION_TIMEOUT)

    def ScpPushFile(self, src_file, dst_file):
        """Scp push file to remote.

        Args:
            src_file: The local source file path to be pushed.
            dst_file: The remote destination file path the file is pushed to.
        """
        scp_command = [self.GetBaseCmd(constants.SCP_BIN)]
        scp_command.append(src_file)
        scp_command.append("%s@%s:%s" %(self._user, self._ip, dst_file))
        ShellCmdWithRetry(" ".join(scp_command))

    def ScpPullFile(self, src_file, dst_file):
        """Scp pull file from remote.

        Args:
            src_file: The remote source file path to be pulled.
            dst_file: The local destination file path the file is pulled to.
        """
        scp_command = [self.GetBaseCmd(constants.SCP_BIN)]
        scp_command.append("%s@%s:%s" %(self._user, self._ip, src_file))
        scp_command.append(dst_file)
        ShellCmdWithRetry(" ".join(scp_command))
327