#!/usr/bin/env python
#
# Copyright 2016 - The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for acloud.internal.lib.utils."""

import errno
import getpass
import grp
import os
import shutil
import subprocess
import tempfile
import time
import webbrowser

import unittest
import six
import mock

from acloud import errors
from acloud.internal.lib import driver_test_lib
from acloud.internal.lib import utils


# Tkinter may not be supported so mock it out.
try:
    import Tkinter
except ImportError:
    Tkinter = mock.Mock()


class FakeTkinter(object):
    """Fake implementation of Tkinter.Tk().

    Only the two screen-geometry accessors used by the VNC ratio test are
    provided; they return the values given to the constructor.
    """

    def __init__(self, width=None, height=None):
        self.width = width
        self.height = height

    # pylint: disable=invalid-name
    def winfo_screenheight(self):
        """Return the screen height."""
        return self.height

    # pylint: disable=invalid-name
    def winfo_screenwidth(self):
        """Return the screen width."""
        return self.width


# pylint: disable=too-many-public-methods
class UtilsTest(driver_test_lib.BaseDriverTest):
    """Test Utils."""

    # NOTE: unittest only collects methods whose names start with "test"
    # (lower case), so every test method below must use that prefix.

    def testTempDirSuccess(self):  # pylint: disable=invalid-name
        """Test create a temp dir and remove it on a clean exit."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        self.Patch(shutil, "rmtree")
        with utils.TempDir():
            pass
        # Verify.
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirExceptionRaised(self):  # pylint: disable=invalid-name
        """Test create a temp dir and exception is raised within with-clause."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        self.Patch(shutil, "rmtree")

        class ExpectedException(Exception):
            """Expected exception."""

        def _Call():
            with utils.TempDir():
                raise ExpectedException("Expected exception.")

        # Verify. ExpectedException should be raised.
        self.assertRaises(ExpectedException, _Call)
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirWhenDeleteTempDirNoLongerExist(self):  # pylint: disable=invalid-name
        """Test create a temp dir and dir no longer exists during deletion."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = EnvironmentError()
        expected_error.errno = errno.ENOENT
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        def _Call():
            with utils.TempDir():
                pass

        # Verify no exception should be raised when rmtree raises
        # EnvironmentError with errno.ENOENT, i.e.
        # directory no longer exists.
        _Call()
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirWhenDeleteEncounterError(self):
        """Test create a temp dir and encountered error during deletion."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = OSError("Expected OS Error")
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        def _Call():
            with utils.TempDir():
                pass

        # Verify OSError should be raised.
        self.assertRaises(OSError, _Call)
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirOrininalErrorRaised(self):
        """Test original error is raised even if tmp dir deletion failed."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = OSError("Expected OS Error")
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        class ExpectedException(Exception):
            """Expected exception."""

        def _Call():
            with utils.TempDir():
                raise ExpectedException("Expected Exception")

        # Verify.
        # ExpectedException should be raised, and OSError
        # should not be raised.
        self.assertRaises(ExpectedException, _Call)
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testCreateSshKeyPairKeyAlreadyExists(self):  # pylint: disable=invalid-name
        """Test when the key pair already exists."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", side_effect=[True, True])
        self.Patch(subprocess, "check_call")
        self.Patch(os, "makedirs", return_value=True)
        utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        # Nothing should be generated when both key files are present.
        self.assertEqual(subprocess.check_call.call_count, 0)  # pylint: disable=no-member

    def testCreateSshKeyPairKeyAreCreated(self):
        """Test when the key pair created."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", return_value=False)
        self.Patch(os, "makedirs", return_value=True)
        self.Patch(subprocess, "check_call")
        self.Patch(os, "rename")
        utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_call.call_count, 1)  # pylint: disable=no-member
        subprocess.check_call.assert_called_with(  # pylint: disable=no-member
            utils.SSH_KEYGEN_CMD +
            ["-C", getpass.getuser(), "-f", private_key],
            stdout=mock.ANY,
            stderr=mock.ANY)

    def testCreatePublicKeyAreCreated(self):
        """Test when the PublicKey created."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", side_effect=[False, True, True])
        self.Patch(os, "makedirs", return_value=True)
        mock_open = mock.mock_open(read_data=public_key)
        self.Patch(subprocess, "check_output")
        self.Patch(os, "rename")
        with mock.patch.object(six.moves.builtins, "open", mock_open):
            utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_output.call_count, 1)  # pylint: disable=no-member
        subprocess.check_output.assert_called_with(  # pylint: disable=no-member
            utils.SSH_KEYGEN_PUB_CMD + ["-f", private_key])

    def testRetryOnException(self):  # pylint: disable=invalid-name
        """Test the RetryOnException decorator retries num_retry times."""
        # Keep the test fast and hermetic in case the retry helper sleeps
        # between attempts.
        self.Patch(time, "sleep")

        def _IsValueError(exc):
            return isinstance(exc, ValueError)

        num_retry = 5

        @utils.RetryOnException(_IsValueError, num_retry)
        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        sentinel = mock.MagicMock()
        self.assertRaises(ValueError, _RaiseAndRetry, sentinel)
        # One initial attempt plus num_retry retries.
        self.assertEqual(1 + num_retry, sentinel.alert.call_count)

    def testRetryExceptionType(self):
        """Test RetryExceptionType function."""

        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        num_retry = 5
        sentinel = mock.MagicMock()
        self.assertRaises(
            ValueError,
            utils.RetryExceptionType, (KeyError, ValueError),
            num_retry,
            _RaiseAndRetry,
            0,  # sleep_multiplier
            1,  # retry_backoff_factor
            sentinel=sentinel)
        self.assertEqual(1 + num_retry, sentinel.alert.call_count)

    def testRetry(self):
        """Test Retry sleeps with exponential backoff between attempts."""
        mock_sleep = self.Patch(time, "sleep")

        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        num_retry = 5
        sentinel = mock.MagicMock()
        self.assertRaises(
            ValueError,
            utils.RetryExceptionType, (ValueError, KeyError),
            num_retry,
            _RaiseAndRetry,
            1,  # sleep_multiplier
            2,  # retry_backoff_factor
            sentinel=sentinel)

        self.assertEqual(1 + num_retry, sentinel.alert.call_count)
        mock_sleep.assert_has_calls(
            [
                mock.call(1),
                mock.call(2),
                mock.call(4),
                mock.call(8),
                mock.call(16)
            ])

    @mock.patch.object(six.moves, "input")
    def testGetAnswerFromList(self, mock_raw_input):
        """Test GetAnswerFromList."""
        answer_list = ["image1.zip", "image2.zip", "image3.zip"]
        mock_raw_input.return_value = 0
        with self.assertRaises(SystemExit):
            utils.GetAnswerFromList(answer_list)
        mock_raw_input.side_effect = [1, 2, 3, 4]
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image1.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image2.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image3.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list,
                                                 enable_choose_all=True),
                         answer_list)

    @unittest.skipIf(isinstance(Tkinter, mock.Mock), "Tkinter mocked out, test case not needed.")
    @mock.patch.object(Tkinter, "Tk")
    def testCalculateVNCScreenRatio(self, mock_tk):
        """Test Calculating the scale ratio of VNC display."""
        # Get scale-down ratio if screen height is smaller than AVD height.
        mock_tk.return_value = FakeTkinter(height=800, width=1200)
        avd_h = 1920
        avd_w = 1080
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.4)

        # Get scale-down ratio if screen width is smaller than AVD width.
        mock_tk.return_value = FakeTkinter(height=800, width=1200)
        avd_h = 900
        avd_w = 1920
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.6)

        # Scale ratio = 1 if screen is larger than AVD.
        mock_tk.return_value = FakeTkinter(height=1080, width=1920)
        avd_h = 800
        avd_w = 1280
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 1)

        # Get the scale if ratio of width is smaller than the
        # ratio of height.
        mock_tk.return_value = FakeTkinter(height=1200, width=800)
        avd_h = 1920
        avd_w = 1080
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.6)

    # pylint: disable=protected-access
    def testCheckUserInGroups(self):
        """Test CheckUserInGroups."""
        self.Patch(os, "getgroups", return_value=[1, 2, 3])
        gr1 = mock.MagicMock()
        gr1.gr_name = "fake_gr_1"
        gr2 = mock.MagicMock()
        gr2.gr_name = "fake_gr_2"
        gr3 = mock.MagicMock()
        gr3.gr_name = "fake_gr_3"
        self.Patch(grp, "getgrgid", side_effect=[gr1, gr2, gr3])

        # User in all required groups should return true.
        self.assertTrue(
            utils.CheckUserInGroups(
                ["fake_gr_1", "fake_gr_2"]))

        # User not in all required groups should return False.
        # The side_effect list is consumed by the first check, so patch again.
        self.Patch(grp, "getgrgid", side_effect=[gr1, gr2, gr3])
        self.assertFalse(
            utils.CheckUserInGroups(
                ["fake_gr_1", "fake_gr_4"]))

    @mock.patch.object(utils, "CheckUserInGroups")
    def testAddUserGroupsToCmd(self, mock_user_group):
        """Test AddUserGroupsToCmd."""
        command = "test_command"
        groups = ["group1", "group2"]
        # Don't add user group in command
        mock_user_group.return_value = True
        expected_value = "test_command"
        self.assertEqual(expected_value, utils.AddUserGroupsToCmd(command,
                                                                  groups))

        # Add user group in command
        mock_user_group.return_value = False
        expected_value = "sg group1 <<EOF\nsg group2\ntest_command\nEOF"
        self.assertEqual(expected_value, utils.AddUserGroupsToCmd(command,
                                                                  groups))

    # pylint: disable=invalid-name
    def testTimeoutException(self):
        """Test TimeoutException."""
        @utils.TimeoutException(1, "should time out")
        def functionThatWillTimeOut():
            """Test decorator of @utils.TimeoutException should timeout."""
            time.sleep(5)

        self.assertRaises(errors.FunctionTimeoutError,
                          functionThatWillTimeOut)

    def testTimeoutExceptionNoTimeout(self):
        """Test No TimeoutException."""
        @utils.TimeoutException(5, "shouldn't time out")
        def functionThatShouldNotTimeout():
            """Test decorator of @utils.TimeoutException shouldn't timeout."""
            return None
        try:
            functionThatShouldNotTimeout()
        except errors.FunctionTimeoutError:
            self.fail("shouldn't timeout")

    def testAutoConnectCreateSSHTunnelFail(self):
        """Test auto connect returns empty ports when the ssh tunnel fails."""
        fake_ip_addr = "1.1.1.1"
        fake_rsa_key_file = "/tmp/rsa_file"
        fake_target_vnc_port = 8888
        target_adb_port = 9999
        ssh_user = "fake_user"
        call_side_effect = subprocess.CalledProcessError(123, "fake",
                                                         "fake error")
        result = utils.ForwardedPorts(vnc_port=None, adb_port=None)
        self.Patch(subprocess, "check_call", side_effect=call_side_effect)
        self.assertEqual(result, utils.AutoConnect(fake_ip_addr,
                                                   fake_rsa_key_file,
                                                   fake_target_vnc_port,
                                                   target_adb_port,
                                                   ssh_user))

    # pylint: disable=protected-access,no-member
    def testExtraArgsSSHTunnel(self):
        """Test extra args will be the same with expanded args."""
        fake_ip_addr = "1.1.1.1"
        fake_rsa_key_file = "/tmp/rsa_file"
        fake_target_vnc_port = 8888
        target_adb_port = 9999
        ssh_user = "fake_user"
        fake_port = 12345
        self.Patch(utils, "PickFreePort", return_value=fake_port)
        self.Patch(utils, "_ExecuteCommand")
        self.Patch(subprocess, "check_call", return_value=True)
        extra_args_ssh_tunnel = "-o command='shell %s %h' -o command1='ls -la'"
        utils.AutoConnect(ip_addr=fake_ip_addr,
                          rsa_key_file=fake_rsa_key_file,
                          target_vnc_port=fake_target_vnc_port,
                          target_adb_port=target_adb_port,
                          ssh_user=ssh_user,
                          client_adb_port=fake_port,
                          extra_args_ssh_tunnel=extra_args_ssh_tunnel)
        args_list = ["-i", "/tmp/rsa_file",
                     "-o", "UserKnownHostsFile=/dev/null",
                     "-o", "StrictHostKeyChecking=no",
                     "-L", "12345:127.0.0.1:9999",
                     "-L", "12345:127.0.0.1:8888",
                     "-N", "-f", "-l", "fake_user", "1.1.1.1",
                     "-o", "command=shell %s %h",
                     "-o", "command1=ls -la"]
        first_call_args = utils._ExecuteCommand.call_args_list[0][0]
        self.assertEqual(first_call_args[1], args_list)

    # pylint: disable=protected-access,no-member
    def testEstablishWebRTCSshTunnel(self):
        """Test establish WebRTC ssh tunnel, without and with extra args."""
        fake_ip_addr = "1.1.1.1"
        fake_rsa_key_file = "/tmp/rsa_file"
        ssh_user = "fake_user"
        self.Patch(utils, "ReleasePort")
        self.Patch(utils, "_ExecuteCommand")
        self.Patch(subprocess, "check_call", return_value=True)

        # First call: no extra ssh arguments.
        utils.EstablishWebRTCSshTunnel(
            ip_addr=fake_ip_addr, rsa_key_file=fake_rsa_key_file,
            ssh_user=ssh_user, extra_args_ssh_tunnel=None)
        args_list = ["-i", "/tmp/rsa_file",
                     "-o", "UserKnownHostsFile=/dev/null",
                     "-o", "StrictHostKeyChecking=no",
                     "-L", "8443:127.0.0.1:8443",
                     "-L", "15550:127.0.0.1:15550",
                     "-L", "15551:127.0.0.1:15551",
                     "-N", "-f", "-l", "fake_user", "1.1.1.1"]
        first_call_args = utils._ExecuteCommand.call_args_list[0][0]
        self.assertEqual(first_call_args[1], args_list)

        # Second call: extra ssh arguments are appended after the host.
        extra_args_ssh_tunnel = "-o command='shell %s %h'"
        utils.EstablishWebRTCSshTunnel(
            ip_addr=fake_ip_addr, rsa_key_file=fake_rsa_key_file,
            ssh_user=ssh_user, extra_args_ssh_tunnel=extra_args_ssh_tunnel)
        args_list_with_extra_args = ["-i", "/tmp/rsa_file",
                                     "-o", "UserKnownHostsFile=/dev/null",
                                     "-o", "StrictHostKeyChecking=no",
                                     "-L", "8443:127.0.0.1:8443",
                                     "-L", "15550:127.0.0.1:15550",
                                     "-L", "15551:127.0.0.1:15551",
                                     "-N", "-f", "-l", "fake_user", "1.1.1.1",
                                     "-o", "command=shell %s %h"]
        second_call_args = utils._ExecuteCommand.call_args_list[1][0]
        self.assertEqual(second_call_args[1], args_list_with_extra_args)

    # pylint: disable=protected-access, no-member
    def testCleanupSSVncviwer(self):
        """test cleanup ssvnc viewer."""
        fake_vnc_port = 9999
        fake_ss_vncviewer_pattern = utils._SSVNC_VIEWER_PATTERN % {
            "vnc_port": fake_vnc_port}
        self.Patch(utils, "IsCommandRunning", return_value=True)
        self.Patch(subprocess, "check_call", return_value=True)
        utils.CleanupSSVncviewer(fake_vnc_port)
        subprocess.check_call.assert_called_with(["pkill", "-9", "-f", fake_ss_vncviewer_pattern])

        # Clear the recorded calls (not just the counter) before checking
        # that nothing is killed when the viewer is not running.
        subprocess.check_call.reset_mock()
        self.Patch(utils, "IsCommandRunning", return_value=False)
        utils.CleanupSSVncviewer(fake_vnc_port)
        subprocess.check_call.assert_not_called()

    def testLaunchBrowserFromReport(self):
        """test launch browser from report."""
        self.Patch(webbrowser, "open_new_tab")
        fake_report = mock.MagicMock(data={})

        # test remote instance
        self.Patch(os.environ, "get", return_value=True)
        fake_report.data = {
            "devices": [{"instance_name": "remote_cf_instance_name",
                         "ip": "192.168.1.1",},],}

        utils.LaunchBrowserFromReport(fake_report)
        webbrowser.open_new_tab.assert_called_once_with("https://localhost:8443/?use_tcp=true")
        webbrowser.open_new_tab.reset_mock()

        # test local instance
        fake_report.data = {
            "devices": [{"instance_name": "local-instance1",
                         "ip": "127.0.0.1:6250",},],}
        utils.LaunchBrowserFromReport(fake_report)
        webbrowser.open_new_tab.assert_called_once_with("https://localhost:8443/?use_tcp=true")
        webbrowser.open_new_tab.reset_mock()

        # verify terminal can't support launch webbrowser.
        self.Patch(os.environ, "get", return_value=False)
        utils.LaunchBrowserFromReport(fake_report)
        webbrowser.open_new_tab.assert_not_called()


if __name__ == "__main__":
    unittest.main()