#!/usr/bin/env python
#
# Copyright (c) 2008 Canonical
#
# Written by Marc Tardif <marc@interunion.ca>
#
# This file is part of HWTest.
#
# HWTest is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# HWTest is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with HWTest.  If not, see <http://www.gnu.org/licenses/>.
#
import optparse

import os
import sys

# Put the directory containing this script at the front of sys.path, so
# that imports are looked up there before any other sys.path entry.
lib_dir = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, lib_dir)


def find_tests(testpaths=()):
    """Collect the test files located below this script's directory.

    A file is a candidate only if its path, relative to this script's
    directory, contains the substring "/tests/".  Files named
    "__init__.py" and files ending in ".pyc" are never candidates.

    @param testpaths: Optional sequence of path prefixes, relative to
                      this script's directory.  When non-empty, only
                      candidates whose relative path starts with one of
                      the prefixes are kept.  When empty, every
                      candidate is kept.
    @return: (unittests, doctests) tuple of lists of relative paths,
             holding the ".py" candidates and the ".txt" candidates,
             respectively.
    """
    base = os.path.abspath(os.path.dirname(__file__))
    # Number of leading characters to drop to turn an absolute path into
    # one relative to base (the directory itself plus one separator).
    skip = len(base) + 1
    # str.startswith accepts a tuple of alternatives.
    prefixes = tuple(set(testpaths))

    unittests = []
    doctests = []
    for dirpath, _subdirs, names in os.walk(base):
        for name in names:
            relpath = os.path.join(dirpath, name)[skip:]

            # Package markers and compiled bytecode are never tests.
            if name == "__init__.py" or name.endswith(".pyc"):
                continue
            # Neither is anything living outside of a tests directory.
            if "/tests/" not in relpath:
                continue
            # An explicit selection narrows the candidates by prefix.
            if prefixes and not relpath.startswith(prefixes):
                continue

            if name.endswith(".py"):
                unittests.append(relpath)
            elif name.endswith(".txt"):
                doctests.append(relpath)

    return unittests, doctests

def parse_sys_argv():
    """Remove and return the non-option arguments found in sys.argv.

    Every argument after sys.argv[0] that does not start with '-' is
    removed from sys.argv in place.  The removed arguments are returned
    as a list, in the reverse of the order in which they appeared.
    """
    given = sys.argv[1:]
    positional = [arg for arg in given if not arg.startswith("-")]
    # Slice assignment keeps the very same list object bound to sys.argv.
    sys.argv[1:] = [arg for arg in given if arg.startswith("-")]
    positional.reverse()
    return positional

def test_with_unittest():
    """Run the discovered tests with unittest and doctest, then exit.

    The command line is parsed with optparse: ``--verbose`` sets the
    unittest runner verbosity to 2, and positional arguments are handed
    to find_tests() to restrict which test files are run.  Each unittest
    module is imported and run on its own, each doctest file is run with
    the ELLIPSIS option, and per-file counts followed by overall totals
    are printed.

    This function does not return: it calls sys.exit() with status 1 if
    any failure or error was recorded, and with status 0 otherwise.

    All output uses print with a single parenthesised argument, which
    produces the same text whether print is a statement (Python 2) or a
    function (Python 3).
    """
    import unittest
    import doctest

    usage = "test [options] [<test filename>, ...]"

    parser = optparse.OptionParser(usage=usage)
    parser.add_option('--verbose', action='store_true')
    opts, args = parser.parse_args()

    runner = unittest.TextTestRunner()
    if opts.verbose:
        runner.verbosity = 2

    loader = unittest.TestLoader()
    unittests, doctests = find_tests(args)

    class Summary(object):
        """Accumulate test, failure and error counts across test files."""

        def __init__(self):
            self.total_failures = 0
            self.total_errors = 0
            self.total_tests = 0

        def __call__(self, tests, failures, errors):
            """Add the counts of one test file and print them."""
            self.total_tests += tests
            self.total_failures += failures
            self.total_errors += errors
            print("(tests=%d, failures=%d, errors=%d)"
                  % (tests, failures, errors))

    unittest_summary = Summary()
    doctest_summary = Summary()

    if unittests:
        print("Running unittests...")
        for relpath in unittests:
            print("[%s]" % relpath)
            # Turn "pkg/tests/test_x.py" into "pkg.tests.test_x".
            modpath = relpath.replace('/', '.')[:-3]
            # A non-empty fromlist makes __import__ return the submodule
            # itself rather than the top-level package.
            module = __import__(modpath, None, None, [""])
            test = loader.loadTestsFromModule(module)
            result = runner.run(test)
            unittest_summary(test.countTestCases(),
                             len(result.failures), len(result.errors))
            print("")

    if doctests:
        print("Running doctests...")
        doctest_flags = doctest.ELLIPSIS
        for relpath in doctests:
            print("[%s]" % relpath)
            failures, total = doctest.testfile(relpath,
                                               optionflags=doctest_flags)
            # doctest.testfile only reports failures, so errors stay 0.
            doctest_summary(total, failures, 0)
            print("")

    total_failures = (unittest_summary.total_failures +
                      doctest_summary.total_failures)
    total_errors = (unittest_summary.total_errors +
                    doctest_summary.total_errors)

    print("Total test cases: %d" % unittest_summary.total_tests)
    print("Total doctests: %d" % doctest_summary.total_tests)
    print("Total failures: %d" % total_failures)
    print("Total errors: %d" % total_errors)

    # Exit status 1 when anything failed or errored, 0 otherwise.
    failed = bool(total_failures or total_errors)
    sys.exit(int(failed))

if __name__ == "__main__":
    # The runner is chosen through the HWTEST_TEST_RUNNER environment
    # variable; an unset or empty value selects "unittest".
    runner = os.environ.get("HWTEST_TEST_RUNNER") or "unittest"
    # A runner name maps onto the module-level function named
    # test_with_<name>, with any dots replaced by underscores.
    runner_func = globals().get("test_with_" + runner.replace(".", "_"))
    if runner_func:
        runner_func()
    else:
        sys.exit("Test runner not found: %s" % runner)

# vim:ts=4:sw=4:et
