1#!/usr/bin/env python3
2#
3# Copyright 2018 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Main entrypoint for all of atest's unittest."""
18
19import logging
20import os
21import sys
22import unittest
23
24from importlib import import_module
25
26import atest_utils
27
COVERAGE = 'coverage'
# Command-line flags. Only the real arguments (sys.argv[1:]) are inspected so
# the program name in sys.argv[0] can never be mistaken for a flag.
RUN_COVERAGE = COVERAGE in sys.argv[1:]
SHOW_MISSING = '--show-missing' in sys.argv[1:]
# Disable logging at ERROR level and below so log output from the code under
# test stays quiet and the unittests can pass through TF.
logging.disable(logging.ERROR)
33
def get_test_modules(base_path=None):
    """Returns a list of testable modules.

    Finds all the test files (*_unittest.py) under base_path, takes each
    file's path relative to base_path (e.g. internal/lib/utils_unittest.py),
    strips the .py extension and turns path separators into dots to form an
    import path (e.g. internal.lib.utils_unittest).

    Args:
        base_path: Directory to search recursively. Defaults to the directory
            containing this file.

    Returns:
        Sorted list of strings (the testable module import paths). Sorting
        keeps the test order independent of os.walk()'s listing order.
    """
    if base_path is None:
        base_path = os.path.dirname(os.path.realpath(__file__))

    testable_modules = []
    for dirpath, _, files in os.walk(base_path):
        for file_name in files:
            if not file_name.endswith('_unittest.py'):
                continue
            # Transform the file into a relative, dotted import path.
            rel_file_path = os.path.relpath(
                os.path.join(dirpath, file_name), base_path)
            module_path, _ = os.path.splitext(rel_file_path)
            testable_modules.append(module_path.replace(os.sep, '.'))

    return sorted(testable_modules)
58
def run_test_modules(test_modules, verbosity=2):
    """Main method of running unit tests.

    Each module is imported explicitly first, so a module that fails to
    import raises here before any test is run.

    Args:
        test_modules: A list of module import paths, e.g. as returned by
            get_test_modules().
        verbosity: Verbosity level handed to unittest.TextTestRunner.
            Defaults to 2 (one line per test).

    Returns:
        The unittest.TextTestResult produced by the run.
    """
    for module_name in test_modules:
        import_module(module_name)

    test_suite = unittest.defaultTestLoader.loadTestsFromNames(test_modules)
    return unittest.TextTestRunner(verbosity=verbosity).run(test_suite)
75
# pylint: disable=import-outside-toplevel
def main(run_coverage=False, show_missing=False):
    """Main unittest entry.

    Runs every discovered *_unittest module. When coverage is requested and
    the coverage module is available, it also measures coverage and writes a
    terminal report and an HTML report.

    Args:
        run_coverage: True to collect coverage while running the tests. If
            the coverage module is not available the tests still run, just
            without coverage.
        show_missing: True to list uncovered line numbers in the coverage
            report. Only used when coverage is collected.

    Raises:
        SystemExit: Once the tests have run. The exit status is 0 if every
            test passed and 1 otherwise.
    """
    # Short-circuit so atest_utils.has_python_module() is only called when
    # coverage was actually requested.
    if not (run_coverage and atest_utils.has_python_module(COVERAGE)):
        result = run_test_modules(get_test_modules())
        sys.exit(0 if result.wasSuccessful() else 1)

    from coverage import Coverage
    # cover_pylib=False (coverage's default) ignores only the standard
    # library; therefore these 3rd-party libs must be omitted explicitly
    # before creating the Coverage object.
    ignore_libs = ['*/__init__.py',
                   '*dist-packages/*.py',
                   '*site-packages/*.py']
    cov = Coverage(omit=ignore_libs)
    cov.erase()
    cov.start()
    try:
        result = run_test_modules(get_test_modules())
    finally:
        # Always stop tracing, even if importing a test module raised.
        cov.stop()
    if not result.wasSuccessful():
        # Discard the data of a failing run instead of reporting it.
        cov.erase()
        sys.exit(1)
    cov.save()
    cov.report(show_missing=show_missing)
    cov.html_report()
    sys.exit(0)
109
110
if __name__ == '__main__':
    # With no command-line arguments, use main()'s defaults; otherwise pass
    # along the flags read from sys.argv at module level.
    if len(sys.argv) <= 1:
        main()
    else:
        main(RUN_COVERAGE, SHOW_MISSING)
116