1#!/usr/bin/env python3
2# -*- coding:utf-8 -*-
3# Copyright 2016 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#      http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Unittests for the shell module."""
18
19from __future__ import print_function
20
21import difflib
22import os
23import sys
24import unittest
25
# Resolve the directory two levels up from this file's path (i.e. the parent
# of the directory containing this file) and make sure it is the first
# sys.path entry, so the local `rh` package imported below can be found.
_path = os.path.realpath(__file__ + '/../..')
if sys.path[0] != _path:
    sys.path.insert(0, _path)
# The helper name is only needed for the tweak; drop it from the namespace.
del _path
30
31# We have to import our local modules after the sys.path tweak.  We can't use
32# relative imports because this is an executable program, not a module.
33# pylint: disable=wrong-import-position
34import rh.shell
35
36
class DiffTestCase(unittest.TestCase):
    """Helper that includes diff output when failing."""

    def setUp(self):
        """Create the differ used to render expected-vs-actual output."""
        self.differ = difflib.Differ()

    def _assertEqual(self, func, test_input, test_output, result):
        """Like assertEqual but with built in diff support.

        Args:
          func: Name of the function under test (only used in the message).
          test_input: The value that was handed to the function.
          test_output: The value the function was expected to return.
          result: The value the function actually returned.
        """
        diff = '\n'.join(self.differ.compare([test_output], [result]))
        msg = ('Expected %s to translate %r to %r, but got %r\n%s' %
               (func, test_input, test_output, result, diff))
        self.assertEqual(test_output, result, msg)

    def _testData(self, functor, tests, check_type=True):
        """Process a dict of test data.

        Args:
          functor: Callable under test; called once with each input value.
          tests: Dict mapping each expected output to the input producing it.
          check_type: Whether to also require every result to be a str.
        """
        for test_output, test_input in tests.items():
            result = functor(test_input)
            self._assertEqual(functor.__name__, test_input, test_output, result)

            if check_type:
                # Also make sure the result is a string, otherwise the %r
                # output will include a "u" prefix and that is not good for
                # logging.  This must inspect the value the functor returned;
                # checking the expected value would only validate the test
                # table itself.
                self.assertIsInstance(result, str)
61
62
class ShellQuoteTest(DiffTestCase):
    """Exercise shell_quote & shell_unquote, alone and as a round trip."""

    def testShellQuote(self):
        """Check quoting, unquoting, and that quote-then-unquote is lossless."""

        def aux(text):
            # Quote first, then undo it; the caller compares with the input.
            quoted = rh.shell.shell_quote(text)
            return rh.shell.shell_unquote(quoted)

        # Each expected quoted form is keyed to the raw string producing it.
        quote_cases = {
            "''": '',
            'a': u'a',
            "'a b c'": u'a b c',
            "'a\tb'": 'a\tb',
            "'/a$file'": '/a$file',
            "'/a#file'": '/a#file',
            """'b"c'""": 'b"c',
            "'a@()b'": 'a@()b',
            'j%k': 'j%k',
            r'''"s'a\$va\\rs"''': r"s'a$va\rs",
            r'''"\\'\\\""''': r'''\'\"''',
            r'''"'\\\$"''': r"""'\$""",
        }

        # Unquote-only case: shell_quote never emits this form, yet it is
        # still a valid bash escaped string that shell_unquote must accept.
        unquote_cases = {
            r'''\$''': r'''"\\$"''',
        }

        self._testData(rh.shell.shell_quote, quote_cases)
        self._testData(rh.shell.shell_unquote, unquote_cases)

        # Round trip: first over the raw inputs, then over the quoted forms.
        # Each string is mapped to itself because aux() should return its
        # argument unchanged.
        raw_strings = list(quote_cases.values())
        quoted_strings = list(quote_cases)
        for strings in (raw_strings, quoted_strings):
            self._testData(aux, dict(zip(strings, strings)), False)
99
100
class CmdToStrTest(DiffTestCase):
    """Check how cmd_to_str renders argument lists."""

    def testCmdToStr(self):
        """Feed several argument lists through cmd_to_str and compare."""
        # (expected string, argument list) pairs, turned into the
        # expected-output -> input mapping that _testData consumes.
        cases = (
            (r"a b", ['a', 'b']),
            (r"'a b' c", ['a b', 'c']),
            (r'''a "b'c"''', ['a', "b'c"]),
            (r'''a "/'\$b" 'a b c' "xy'z"''',
             [u'a', "/'$b", 'a b c', "xy'z"]),
            ('', []),
        )
        self._testData(rh.shell.cmd_to_str, dict(cases))
115
116
class BooleanShellTest(unittest.TestCase):
    """Test the boolean_shell_value function."""

    def testFull(self):
        """Verify inputs work as expected."""
        to_bool = rh.shell.boolean_shell_value

        # None is expected to yield whichever default was supplied.
        self.assertTrue(to_bool(None, True))
        self.assertFalse(to_bool(None, False))

        # Values that are not recognised must be rejected with ValueError.
        invalid_values = (1234, '', 'akldjsf', '"')
        for bad in invalid_values:
            with self.assertRaises(ValueError):
                to_bool(bad, True)

        # Recognised spellings win over the default, whatever it is.
        truthy_words = (
            'yes', 'YES', 'YeS', 'y', 'Y', '1', 'true', 'True', 'TRUE',
        )
        falsy_words = (
            'no', 'NO', 'nO', 'n', 'N', '0', 'false', 'False', 'FALSE',
        )
        defaults = (True, False)

        for word in truthy_words:
            for default in defaults:
                self.assertTrue(to_bool(word, default))

        for word in falsy_words:
            for default in defaults:
                self.assertFalse(to_bool(word, default))
136
137
# When executed directly as a script, hand control to the unittest runner.
if __name__ == '__main__':
    unittest.main()
140