1#!/usr/bin/env python
2#
3# Copyright 2019 - The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Tests for acloud.internal.lib.ssh."""
18
19import subprocess
20import unittest
21import threading
22import time
23import mock
24
25from acloud import errors
26from acloud.internal import constants
27from acloud.internal.lib import driver_test_lib
28from acloud.internal.lib import ssh
29
30
class SshTest(driver_test_lib.BaseDriverTest):
    """Tests for the ssh module: command construction and subprocess calls."""

    FAKE_SSH_PRIVATE_KEY_PATH = "/fake/acloud_rea"
    FAKE_SSH_USER = "fake_user"
    FAKE_IP = ssh.IP(external="1.1.1.1", internal="10.1.1.1")
    # Extra ssh args; the expected commands below show them placed after the
    # default -o options and before the user/host part.
    FAKE_EXTRA_ARGS_SSH = "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22'"
    FAKE_REPORT_INTERNAL_IP = True

    def setUp(self):
        """Set up a fake subprocess object for patched Popen calls to return.

        The fake reports a finished, successful process: stdout.readline()
        returns b"", poll() and returncode are 0, and communicate() returns
        empty output.
        """
        super(SshTest, self).setUp()
        self.created_subprocess = mock.MagicMock()
        self.created_subprocess.stdout = mock.MagicMock()
        self.created_subprocess.stdout.readline = mock.MagicMock(return_value=b"")
        self.created_subprocess.poll = mock.MagicMock(return_value=0)
        self.created_subprocess.returncode = 0
        self.created_subprocess.communicate = mock.MagicMock(return_value=
                                                             ('', ''))

    # pylint: disable=no-member
    @staticmethod
    def _AssertPopenCalledWith(expected_cmd):
        """Assert the last subprocess.Popen call ran expected_cmd.

        Every Ssh command test expects the same Popen keyword arguments: the
        command runs through the shell, stdout is piped, and stderr is merged
        into stdout.

        Args:
            expected_cmd: String, the full shell command line Popen should get.
        """
        subprocess.Popen.assert_called_with(expected_cmd,
                                            shell=True,
                                            stderr=subprocess.STDOUT,
                                            stdin=None,
                                            stdout=subprocess.PIPE)

    def testSSHExecuteWithRetry(self):
        """Test ShellCmdWithRetry raises CalledProcessError when every call fails."""
        # Avoid real delays between retry attempts.
        self.Patch(time, "sleep")
        self.Patch(subprocess, "Popen",
                   side_effect=subprocess.CalledProcessError(
                       None, "ssh command fail."))
        self.assertRaises(subprocess.CalledProcessError,
                          ssh.ShellCmdWithRetry,
                          "fake cmd")

    def testGetBaseCmdWithInternalIP(self):
        """Test GetBaseCmd targets the internal ip when report_internal_ip is set."""
        ssh_object = ssh.Ssh(ip=self.FAKE_IP,
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
                             report_internal_ip=self.FAKE_REPORT_INTERNAL_IP)
        expected_ssh_cmd = ("/usr/bin/ssh -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                            "-o StrictHostKeyChecking=no -l fake_user 10.1.1.1")
        self.assertEqual(ssh_object.GetBaseCmd(constants.SSH_BIN), expected_ssh_cmd)

    def testGetBaseCmd(self):
        """Test GetBaseCmd for ssh and scp using the external ip."""
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        expected_ssh_cmd = ("/usr/bin/ssh -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                            "-o StrictHostKeyChecking=no -l fake_user 1.1.1.1")
        self.assertEqual(ssh_object.GetBaseCmd(constants.SSH_BIN), expected_ssh_cmd)

        # The scp base command carries no user/host; those go into the paths.
        expected_scp_cmd = ("/usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                            "-o StrictHostKeyChecking=no")
        self.assertEqual(ssh_object.GetBaseCmd(constants.SCP_BIN), expected_scp_cmd)

    def testSshRunCmd(self):
        """Test Run builds and executes the ssh command."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        ssh_object.Run("command")
        expected_cmd = ("exec /usr/bin/ssh -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no -l fake_user 1.1.1.1 command")
        self._AssertPopenCalledWith(expected_cmd)

    def testSshRunCmdwithExtraArgs(self):
        """Test Run builds and executes the ssh command with extra ssh args."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP,
                             self.FAKE_SSH_USER,
                             self.FAKE_SSH_PRIVATE_KEY_PATH,
                             self.FAKE_EXTRA_ARGS_SSH)
        ssh_object.Run("command")
        expected_cmd = ("exec /usr/bin/ssh -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no "
                        "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22' "
                        "-l fake_user 1.1.1.1 command")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPullFileCmd(self):
        """Test ScpPullFile builds and executes the scp pull command."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        ssh_object.ScpPullFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = ("exec /usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no fake_user@1.1.1.1:/tmp/test /tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPullFileCmdwithExtraArgs(self):
        """Test ScpPullFile builds and executes the scp pull command with extra ssh args."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP,
                             self.FAKE_SSH_USER,
                             self.FAKE_SSH_PRIVATE_KEY_PATH,
                             self.FAKE_EXTRA_ARGS_SSH)
        ssh_object.ScpPullFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = ("exec /usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no "
                        "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22' "
                        "fake_user@1.1.1.1:/tmp/test /tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPushFileCmd(self):
        """Test ScpPushFile builds and executes the scp push command."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        ssh_object.ScpPushFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = ("exec /usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no /tmp/test fake_user@1.1.1.1:/tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPushFileCmdwithExtraArgs(self):
        """Test ScpPushFile builds and executes the scp push command with extra ssh args."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP,
                             self.FAKE_SSH_USER,
                             self.FAKE_SSH_PRIVATE_KEY_PATH,
                             self.FAKE_EXTRA_ARGS_SSH)
        ssh_object.ScpPushFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = ("exec /usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no "
                        "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22' "
                        "/tmp/test fake_user@1.1.1.1:/tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    # pylint: disable=protected-access
    def testIPAddress(self):
        """Test which ip address Ssh selects from an IP object."""
        # Internal ip case.
        ssh_object = ssh.Ssh(ip=ssh.IP(external="1.1.1.1", internal="10.1.1.1"),
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
                             report_internal_ip=True)
        expected_ip = "10.1.1.1"
        self.assertEqual(ssh_object._ip, expected_ip)

        # External ip case.
        ssh_object = ssh.Ssh(ip=ssh.IP(external="1.1.1.1", internal="10.1.1.1"),
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH)
        expected_ip = "1.1.1.1"
        self.assertEqual(ssh_object._ip, expected_ip)

        # Only one ip case.
        ssh_object = ssh.Ssh(ip=ssh.IP(ip="1.1.1.1"),
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH)
        expected_ip = "1.1.1.1"
        self.assertEqual(ssh_object._ip, expected_ip)

    def testWaitForSsh(self):
        """Test WaitForSsh raises DeviceConnectionError when the ssh call fails."""
        ssh_object = ssh.Ssh(ip=self.FAKE_IP,
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
                             report_internal_ip=self.FAKE_REPORT_INTERNAL_IP)
        # A non-zero return value makes every ssh probe report failure.
        self.Patch(ssh, "_SshCall", return_value=-1)
        self.assertRaises(errors.DeviceConnectionError,
                          ssh_object.WaitForSsh,
                          timeout=1,
                          max_retry=1)

    def testSshCallWait(self):
        """Test _SshCallWait creates no timer when no timeout is given."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        ssh._SshCallWait(fake_cmd)
        threading.Timer.assert_not_called()

    def testSshCallWaitTimeout(self):
        """Test _SshCallWait creates exactly one timer when given a timeout."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        fake_timeout = 30
        ssh._SshCallWait(fake_cmd, fake_timeout)
        threading.Timer.assert_called_once()

    def testSshCall(self):
        """Test _SshCall creates no timer when no timeout is given."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        ssh._SshCall(fake_cmd)
        threading.Timer.assert_not_called()

    def testSshCallTimeout(self):
        """Test _SshCall creates exactly one timer when given a timeout."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        fake_timeout = 30
        ssh._SshCall(fake_cmd, fake_timeout)
        threading.Timer.assert_called_once()

    def testSshLogOutput(self):
        """Test _SshLogOutput creates no timer when no timeout is given."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        ssh._SshLogOutput(fake_cmd)
        threading.Timer.assert_not_called()

    def testSshLogOutputTimeout(self):
        """Test _SshLogOutput creates exactly one timer when given a timeout."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        fake_timeout = 30
        ssh._SshLogOutput(fake_cmd, fake_timeout)
        threading.Timer.assert_called_once()
263
if __name__ == "__main__":
    # Run the tests in this module when it is executed as a script.
    unittest.main()
266