1#!/usr/bin/env python
2#
3# Copyright 2016 - The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16"""Tests for acloud.internal.lib.utils."""
17
18import errno
19import getpass
20import grp
21import os
22import shutil
23import subprocess
24import tempfile
25import time
26import webbrowser
27
28import unittest
29import six
30import mock
31
32from acloud import errors
33from acloud.internal.lib import driver_test_lib
34from acloud.internal.lib import utils
35
36
37# Tkinter may not be supported so mock it out.
38try:
39    import Tkinter
40except ImportError:
41    Tkinter = mock.Mock()
42
43
class FakeTkinter(object):
    """Stand-in for the object returned by ``Tkinter.Tk()``.

    It reports whatever screen dimensions it was constructed with, so tests
    can control the values seen by code that queries the display size.

    Args:
        width: Value handed back by winfo_screenwidth().
        height: Value handed back by winfo_screenheight().
    """

    def __init__(self, width=None, height=None):
        self.height = height
        self.width = width

    # pylint: disable=invalid-name
    def winfo_screenwidth(self):
        """Report the configured fake screen width."""
        return self.width

    # pylint: disable=invalid-name
    def winfo_screenheight(self):
        """Report the configured fake screen height."""
        return self.height
60
61
62# pylint: disable=too-many-public-methods
class UtilsTest(driver_test_lib.BaseDriverTest):
    """Tests for the helpers in acloud.internal.lib.utils.

    Every test method name must start with the lowercase prefix "test";
    unittest's default loader silently ignores methods named "Test...".
    """

    def testTempDirSuccess(self):
        """Test create a temp dir."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        self.Patch(shutil, "rmtree")
        with utils.TempDir():
            pass
        # Verify.
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirExceptionRaised(self):
        """Test create a temp dir and exception is raised within with-clause."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        self.Patch(shutil, "rmtree")

        class ExpectedException(Exception):
            """Expected exception."""

        def _Call():
            with utils.TempDir():
                raise ExpectedException("Expected exception.")

        # Verify. ExpectedException should be raised.
        self.assertRaises(ExpectedException, _Call)
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirWhenDeleteTempDirNoLongerExist(self):  # pylint: disable=invalid-name
        """Test create a temp dir and dir no longer exists during deletion."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = EnvironmentError()
        expected_error.errno = errno.ENOENT
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        def _Call():
            with utils.TempDir():
                pass

        # Verify no exception should be raised when rmtree raises
        # EnvironmentError with errno.ENOENT, i.e.
        # directory no longer exists.
        _Call()
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirWhenDeleteEncounterError(self):
        """Test create a temp dir and encountered error during deletion."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = OSError("Expected OS Error")
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        def _Call():
            with utils.TempDir():
                pass

        # Verify OSError should be raised.
        self.assertRaises(OSError, _Call)
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirOriginalErrorRaised(self):
        """Test original error is raised even if tmp dir deletion failed."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = OSError("Expected OS Error")
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        class ExpectedException(Exception):
            """Expected exception."""

        def _Call():
            with utils.TempDir():
                raise ExpectedException("Expected Exception")

        # Verify.
        # ExpectedException should be raised, and OSError
        # should not be raised.
        self.assertRaises(ExpectedException, _Call)
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testCreateSshKeyPairKeyAlreadyExists(self):  # pylint: disable=invalid-name
        """Test when the key pair already exists."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", side_effect=[True, True])
        self.Patch(subprocess, "check_call")
        self.Patch(os, "makedirs", return_value=True)
        utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_call.call_count, 0)  # pylint: disable=no-member

    def testCreateSshKeyPairKeyAreCreated(self):
        """Test when the key pair created."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", return_value=False)
        self.Patch(os, "makedirs", return_value=True)
        self.Patch(subprocess, "check_call")
        self.Patch(os, "rename")
        utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_call.call_count, 1)  # pylint: disable=no-member
        subprocess.check_call.assert_called_with(  # pylint: disable=no-member
            utils.SSH_KEYGEN_CMD +
            ["-C", getpass.getuser(), "-f", private_key],
            stdout=mock.ANY,
            stderr=mock.ANY)

    def testCreatePublicKeyAreCreated(self):
        """Test when the PublicKey created."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", side_effect=[False, True, True])
        self.Patch(os, "makedirs", return_value=True)
        mock_open = mock.mock_open(read_data=public_key)
        self.Patch(subprocess, "check_output")
        self.Patch(os, "rename")
        with mock.patch.object(six.moves.builtins, "open", mock_open):
            utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_output.call_count, 1)  # pylint: disable=no-member
        subprocess.check_output.assert_called_with(  # pylint: disable=no-member
            utils.SSH_KEYGEN_PUB_CMD + ["-f", private_key])

    def testRetryOnException(self):
        """Test RetryOnException decorator retries on matching exceptions."""
        # Guard against real sleeps between retries, whatever the
        # decorator's default sleep settings are.
        self.Patch(time, "sleep")

        def _IsValueError(exc):
            return isinstance(exc, ValueError)

        num_retry = 5

        @utils.RetryOnException(_IsValueError, num_retry)
        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        sentinel = mock.MagicMock()
        self.assertRaises(ValueError, _RaiseAndRetry, sentinel)
        # One initial attempt plus num_retry retries.
        self.assertEqual(1 + num_retry, sentinel.alert.call_count)

    def testRetryExceptionType(self):
        """Test RetryExceptionType function."""

        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        num_retry = 5
        sentinel = mock.MagicMock()
        self.assertRaises(
            ValueError,
            utils.RetryExceptionType, (KeyError, ValueError),
            num_retry,
            _RaiseAndRetry,
            0, # sleep_multiplier
            1, # retry_backoff_factor
            sentinel=sentinel)
        self.assertEqual(1 + num_retry, sentinel.alert.call_count)

    def testRetry(self):
        """Test Retry."""
        mock_sleep = self.Patch(time, "sleep")

        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        num_retry = 5
        sentinel = mock.MagicMock()
        self.assertRaises(
            ValueError,
            utils.RetryExceptionType, (ValueError, KeyError),
            num_retry,
            _RaiseAndRetry,
            1, # sleep_multiplier
            2, # retry_backoff_factor
            sentinel=sentinel)

        self.assertEqual(1 + num_retry, sentinel.alert.call_count)
        # Sleep grows geometrically: sleep_multiplier * backoff ** attempt.
        mock_sleep.assert_has_calls(
            [
                mock.call(1),
                mock.call(2),
                mock.call(4),
                mock.call(8),
                mock.call(16)
            ])

    @mock.patch.object(six.moves, "input")
    def testGetAnswerFromList(self, mock_raw_input):
        """Test GetAnswerFromList."""
        answer_list = ["image1.zip", "image2.zip", "image3.zip"]
        mock_raw_input.return_value = 0
        with self.assertRaises(SystemExit):
            utils.GetAnswerFromList(answer_list)
        mock_raw_input.side_effect = [1, 2, 3, 4]
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image1.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image2.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image3.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list,
                                                 enable_choose_all=True),
                         answer_list)

    @unittest.skipIf(isinstance(Tkinter, mock.Mock), "Tkinter mocked out, test case not needed.")
    @mock.patch.object(Tkinter, "Tk")
    def testCalculateVNCScreenRatio(self, mock_tk):
        """Test Calculating the scale ratio of VNC display."""
        # Get scale-down ratio if screen height is smaller than AVD height.
        mock_tk.return_value = FakeTkinter(height=800, width=1200)
        avd_h = 1920
        avd_w = 1080
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.4)

        # Get scale-down ratio if screen width is smaller than AVD width.
        mock_tk.return_value = FakeTkinter(height=800, width=1200)
        avd_h = 900
        avd_w = 1920
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.6)

        # Scale ratio = 1 if screen is larger than AVD.
        mock_tk.return_value = FakeTkinter(height=1080, width=1920)
        avd_h = 800
        avd_w = 1280
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 1)

        # Get the scale if ratio of width is smaller than the
        # ratio of height.
        mock_tk.return_value = FakeTkinter(height=1200, width=800)
        avd_h = 1920
        avd_w = 1080
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.6)

    # pylint: disable=protected-access
    def testCheckUserInGroups(self):
        """Test CheckUserInGroups."""
        self.Patch(os, "getgroups", return_value=[1, 2, 3])
        gr1 = mock.MagicMock()
        gr1.gr_name = "fake_gr_1"
        gr2 = mock.MagicMock()
        gr2.gr_name = "fake_gr_2"
        gr3 = mock.MagicMock()
        gr3.gr_name = "fake_gr_3"
        self.Patch(grp, "getgrgid", side_effect=[gr1, gr2, gr3])

        # User in all required groups should return True.
        self.assertTrue(
            utils.CheckUserInGroups(
                ["fake_gr_1", "fake_gr_2"]))

        # User not in all required groups should return False.
        # Re-patch so the side_effect iterator starts over.
        self.Patch(grp, "getgrgid", side_effect=[gr1, gr2, gr3])
        self.assertFalse(
            utils.CheckUserInGroups(
                ["fake_gr_1", "fake_gr_4"]))

    @mock.patch.object(utils, "CheckUserInGroups")
    def testAddUserGroupsToCmd(self, mock_user_group):
        """Test AddUserGroupsToCmd."""
        command = "test_command"
        groups = ["group1", "group2"]
        # Don't add user group in command
        mock_user_group.return_value = True
        expected_value = "test_command"
        self.assertEqual(expected_value, utils.AddUserGroupsToCmd(command,
                                                                  groups))

        # Add user group in command
        mock_user_group.return_value = False
        expected_value = "sg group1 <<EOF\nsg group2\ntest_command\nEOF"
        self.assertEqual(expected_value, utils.AddUserGroupsToCmd(command,
                                                                  groups))

    # pylint: disable=invalid-name
    def testTimeoutException(self):
        """Test TimeoutException."""
        @utils.TimeoutException(1, "should time out")
        def functionThatWillTimeOut():
            """Test decorator of @utils.TimeoutException should timeout."""
            time.sleep(5)

        self.assertRaises(errors.FunctionTimeoutError,
                          functionThatWillTimeOut)

    def testTimeoutExceptionNoTimeout(self):
        """Test No TimeoutException."""
        @utils.TimeoutException(5, "shouldn't time out")
        def functionThatShouldNotTimeout():
            """Test decorator of @utils.TimeoutException shouldn't timeout."""
            return None
        try:
            functionThatShouldNotTimeout()
        except errors.FunctionTimeoutError:
            self.fail("shouldn't timeout")

    def testAutoConnectCreateSSHTunnelFail(self):
        """Test auto connect."""
        fake_ip_addr = "1.1.1.1"
        fake_rsa_key_file = "/tmp/rsa_file"
        fake_target_vnc_port = 8888
        target_adb_port = 9999
        ssh_user = "fake_user"
        call_side_effect = subprocess.CalledProcessError(123, "fake",
                                                         "fake error")
        result = utils.ForwardedPorts(vnc_port=None, adb_port=None)
        self.Patch(subprocess, "check_call", side_effect=call_side_effect)
        self.assertEqual(result, utils.AutoConnect(fake_ip_addr,
                                                   fake_rsa_key_file,
                                                   fake_target_vnc_port,
                                                   target_adb_port,
                                                   ssh_user))

    # pylint: disable=protected-access,no-member
    def testExtraArgsSSHTunnel(self):
        """Test extra args will be the same with expanded args."""
        fake_ip_addr = "1.1.1.1"
        fake_rsa_key_file = "/tmp/rsa_file"
        fake_target_vnc_port = 8888
        target_adb_port = 9999
        ssh_user = "fake_user"
        fake_port = 12345
        self.Patch(utils, "PickFreePort", return_value=fake_port)
        self.Patch(utils, "_ExecuteCommand")
        self.Patch(subprocess, "check_call", return_value=True)
        extra_args_ssh_tunnel = "-o command='shell %s %h' -o command1='ls -la'"
        utils.AutoConnect(ip_addr=fake_ip_addr,
                          rsa_key_file=fake_rsa_key_file,
                          target_vnc_port=fake_target_vnc_port,
                          target_adb_port=target_adb_port,
                          ssh_user=ssh_user,
                          client_adb_port=fake_port,
                          extra_args_ssh_tunnel=extra_args_ssh_tunnel)
        # PickFreePort is mocked, so both local ports are fake_port.
        args_list = ["-i", "/tmp/rsa_file",
                     "-o", "UserKnownHostsFile=/dev/null",
                     "-o", "StrictHostKeyChecking=no",
                     "-L", "12345:127.0.0.1:9999",
                     "-L", "12345:127.0.0.1:8888",
                     "-N", "-f", "-l", "fake_user", "1.1.1.1",
                     "-o", "command=shell %s %h",
                     "-o", "command1=ls -la"]
        first_call_args = utils._ExecuteCommand.call_args_list[0][0]
        self.assertEqual(first_call_args[1], args_list)

    # pylint: disable=protected-access,no-member
    def testEstablishWebRTCSshTunnel(self):
        """Test establish WebRTC ssh tunnel."""
        fake_ip_addr = "1.1.1.1"
        fake_rsa_key_file = "/tmp/rsa_file"
        ssh_user = "fake_user"
        self.Patch(utils, "ReleasePort")
        self.Patch(utils, "_ExecuteCommand")
        self.Patch(subprocess, "check_call", return_value=True)

        # Without extra args.
        utils.EstablishWebRTCSshTunnel(
            ip_addr=fake_ip_addr, rsa_key_file=fake_rsa_key_file,
            ssh_user=ssh_user, extra_args_ssh_tunnel=None)
        args_list = ["-i", "/tmp/rsa_file",
                     "-o", "UserKnownHostsFile=/dev/null",
                     "-o", "StrictHostKeyChecking=no",
                     "-L", "8443:127.0.0.1:8443",
                     "-L", "15550:127.0.0.1:15550",
                     "-L", "15551:127.0.0.1:15551",
                     "-N", "-f", "-l", "fake_user", "1.1.1.1"]
        first_call_args = utils._ExecuteCommand.call_args_list[0][0]
        self.assertEqual(first_call_args[1], args_list)

        # With extra args appended to the ssh command.
        extra_args_ssh_tunnel = "-o command='shell %s %h'"
        utils.EstablishWebRTCSshTunnel(
            ip_addr=fake_ip_addr, rsa_key_file=fake_rsa_key_file,
            ssh_user=ssh_user, extra_args_ssh_tunnel=extra_args_ssh_tunnel)
        args_list_with_extra_args = ["-i", "/tmp/rsa_file",
                                     "-o", "UserKnownHostsFile=/dev/null",
                                     "-o", "StrictHostKeyChecking=no",
                                     "-L", "8443:127.0.0.1:8443",
                                     "-L", "15550:127.0.0.1:15550",
                                     "-L", "15551:127.0.0.1:15551",
                                     "-N", "-f", "-l", "fake_user", "1.1.1.1",
                                     "-o", "command=shell %s %h"]
        second_call_args = utils._ExecuteCommand.call_args_list[1][0]
        self.assertEqual(second_call_args[1], args_list_with_extra_args)

    # pylint: disable=protected-access, no-member
    def testCleanupSSVncviewer(self):
        """Test cleanup ssvnc viewer."""
        fake_vnc_port = 9999
        fake_ss_vncviewer_pattern = utils._SSVNC_VIEWER_PATTERN % {
            "vnc_port": fake_vnc_port}
        self.Patch(utils, "IsCommandRunning", return_value=True)
        self.Patch(subprocess, "check_call", return_value=True)
        utils.CleanupSSVncviewer(fake_vnc_port)
        subprocess.check_call.assert_called_with(["pkill", "-9", "-f", fake_ss_vncviewer_pattern])

        # Clear recorded calls (count and args) before the next scenario.
        subprocess.check_call.reset_mock()
        self.Patch(utils, "IsCommandRunning", return_value=False)
        utils.CleanupSSVncviewer(fake_vnc_port)
        subprocess.check_call.assert_not_called()

    def testLaunchBrowserFromReport(self):
        """Test launch browser from report."""
        self.Patch(webbrowser, "open_new_tab")
        fake_report = mock.MagicMock(data={})

        # test remote instance
        self.Patch(os.environ, "get", return_value=True)
        fake_report.data = {
            "devices": [{"instance_name": "remote_cf_instance_name",
                         "ip": "192.168.1.1",},],}

        utils.LaunchBrowserFromReport(fake_report)
        webbrowser.open_new_tab.assert_called_once_with("https://localhost:8443/?use_tcp=true")
        # Clear recorded calls so the next assertion only sees new calls.
        webbrowser.open_new_tab.reset_mock()

        # test local instance
        fake_report.data = {
            "devices": [{"instance_name": "local-instance1",
                         "ip": "127.0.0.1:6250",},],}
        utils.LaunchBrowserFromReport(fake_report)
        webbrowser.open_new_tab.assert_called_once_with("https://localhost:8443/?use_tcp=true")
        webbrowser.open_new_tab.reset_mock()

        # verify terminal can't support launch webbrowser.
        self.Patch(os.environ, "get", return_value=False)
        utils.LaunchBrowserFromReport(fake_report)
        webbrowser.open_new_tab.assert_not_called()
496
497
if __name__ == "__main__":
    # Run every test case defined in this module when executed directly.
    unittest.main()
500