xref: /openbsd-src/gnu/llvm/lldb/third_party/Python/module/unittest2/unittest2/loader.py (revision 061da546b983eb767bad15e67af1174fb0bcf31c)
1"""Loading unittests."""
2
3import functools
4import os
5import re
6import sys
7import traceback
8import types
9import unittest
10
11from fnmatch import fnmatch
12
13from unittest2 import case, suite, cmp_
14
15try:
16    from os.path import relpath
17except ImportError:
18    from unittest2.compatibility import relpath
19
# Frames from modules defining ``__unittest`` are treated by unittest as
# framework internals and trimmed from reported failure tracebacks.
__unittest = True

# Only plain ``.py`` files whose stem is a valid identifier are candidate
# test modules. Compiled variants (``.pyc``/``.pyo``) are intentionally not
# matched: accepting them as well would load the same tests more than once.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', flags=re.IGNORECASE)
26
27
def _make_failed_import_test(name, suiteClass):
    """Build a suite with one test that reports a failed module import.

    Meant to be called from inside an ``except`` block: when available, the
    currently handled traceback is appended to the message so the failing
    test explains why the import of *name* broke.
    """
    details = ['Failed to import test module: %s' % name]
    # format_exc does not exist on Python 2.3; its output also contains a
    # couple of loader frames in addition to the real failure.
    if hasattr(traceback, 'format_exc'):
        details.append(traceback.format_exc())
    error = ImportError('\n'.join(details))
    return _make_failed_test('ModuleImportFailure', name, error, suiteClass)
36
37
def _make_failed_load_tests(name, exception, suiteClass):
    """Wrap an exception raised by a ``load_tests`` hook in a failing test.

    The returned suite holds a single ``LoadTestsFailure`` test whose method
    is called *name* and re-raises *exception* when run.
    """
    failure_class_name = 'LoadTestsFailure'
    return _make_failed_test(failure_class_name, name, exception, suiteClass)
40
41
def _make_failed_test(classname, methodname, exception, suiteClass):
    """Create a suite containing one synthetic, always-failing test.

    A ``case.TestCase`` subclass named *classname* is created on the fly
    with a single method stored under *methodname*; running that test
    re-raises *exception*. The test instance is wrapped in *suiteClass*.
    """
    def testFailure(self):
        raise exception

    failing_case_cls = type(
        classname, (case.TestCase,), {methodname: testFailure})
    return suiteClass((failing_case_cls(methodname),))
48
49
class TestLoader(unittest.TestLoader):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """

    def __init__(self):
        # Run the stdlib base initialiser first: on Python 3.5+ it creates
        # state such as ``self.errors`` that inherited code may consult.
        super(TestLoader, self).__init__()
        self.testMethodPrefix = 'test'
        self.sortTestMethodsUsing = cmp_
        self.suiteClass = suite.TestSuite
        # Kept between discover() calls so that a package's load_tests can
        # call loader.discover() without passing top_level_dir again.
        self._top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass

        Raises TypeError if testCaseClass is a test *suite* class (stdlib or
        unittest2); such a class was almost certainly meant to derive from
        TestCase. Falls back to ``runTest`` when no prefixed methods exist.
        """
        # The stdlib check comes first so the unittest2 one is only consulted
        # for classes that are not already stdlib suites.
        if (issubclass(testCaseClass, unittest.TestSuite) or
                issubclass(testCaseClass, suite.TestSuite)):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module

        Every TestCase subclass found among the module's attributes is
        loaded. If use_load_tests is true and the module defines
        ``load_tests``, it is called as ``load_tests(loader, tests, None)``
        and its result is returned instead; an exception raised by that hook
        becomes a single failing ``LoadTestsFailure`` test.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, unittest.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if use_load_tests and load_tests is not None:
            try:
                return load_tests(self, tests, None)
            except Exception as e:
                return _make_failed_load_tests(module.__name__, e,
                                               self.suiteClass)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Raises ImportError if no prefix of a module-less name is importable,
        AttributeError if a later part cannot be resolved, and TypeError if
        the resolved object cannot be turned into tests.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable dotted prefix of the name; the
            # remaining parts are then resolved as attributes below.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__ returns the top-level package, so only its own
            # name has been consumed.
            parts = parts[1:]
        parent = None
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, unittest.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, (types.MethodType, types.FunctionType)) and
              isinstance(parent, type) and
              issubclass(parent, unittest.TestCase)):
            # Use the attribute name that was looked up, not obj.__name__:
            # a decorator without functools.wraps changes __name__, while
            # TestCase must be constructed with the real attribute name.
            return self.suiteClass([parent(parts[-1])])
        elif isinstance(obj, unittest.TestSuite):
            return obj
        elif hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, unittest.TestSuite):
                return test
            elif isinstance(test, unittest.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass

        Only callable attributes whose names start with testMethodPrefix are
        included; they are ordered with the sortTestMethodsUsing comparison
        function when one is set.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                hasattr(getattr(testCaseClass, attrname), '__call__')
        testFnNames = list(filter(isTestMethod, dir(testCaseClass)))
        if self.sortTestMethodsUsing:
            testFnNames.sort(
                key=functools.cmp_to_key(
                    self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Raises ImportError if the start directory is not importable.
        """
        set_implicit_top = False
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a
            # package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            set_implicit_top = True
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(top_level_dir)

        # Remember whether this call added the entry, so that only our own
        # change to sys.path is ever undone below.
        added_to_sys_path = False
        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            # should we *unconditionally* put the start directory in first
            # in sys.path to minimise likelihood of conflicts between installed
            # modules and development versions?
            sys.path.insert(0, top_level_dir)
            added_to_sys_path = True
        self._top_level_dir = top_level_dir

        is_not_importable = False
        if os.path.isdir(os.path.abspath(start_dir)):
            start_dir = os.path.abspath(start_dir)
            if start_dir != top_level_dir:
                is_not_importable = not os.path.isfile(
                    os.path.join(start_dir, '__init__.py'))
        else:
            # support for discovery from dotted module names
            try:
                __import__(start_dir)
            except ImportError:
                is_not_importable = True
            else:
                the_module = sys.modules[start_dir]
                top_part = start_dir.split('.')[0]
                start_dir = os.path.abspath(
                    os.path.dirname((the_module.__file__)))
                if set_implicit_top:
                    # The implicit top level was the dotted name itself,
                    # which is not a directory; derive the real one from the
                    # location of the top-level package instead.
                    self._top_level_dir = os.path.abspath(os.path.dirname(
                        os.path.dirname(sys.modules[top_part].__file__)))
                    # Never remove an entry the caller already had.
                    if added_to_sys_path:
                        sys.path.remove(top_level_dir)

        if is_not_importable:
            raise ImportError(
                'Start directory is not importable: %r' %
                start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_name_from_path(self, path):
        """Turn a file or directory path below the top-level directory into a
        dotted module name (extension stripped).

        Raises AssertionError if path lies outside the top-level directory.
        """
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = relpath(path, self._top_level_dir)
        # Explicit checks instead of ``assert`` so the validation still runs
        # under ``python -O``; AssertionError is kept for existing callers.
        if os.path.isabs(_relpath) or _relpath.startswith('..'):
            raise AssertionError("Path must be within the project")

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the dotted module name and return the module object."""
        __import__(name)
        return sys.modules[name]

    def _match_path(self, path, full_path, pattern):
        # override this method to use alternative matching strategy
        return fnmatch(path, pattern)

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads.

        Files must look like importable modules and match pattern. A package
        directory whose own name matches pattern is imported; if it defines
        ``load_tests`` that hook handles the whole package and discovery does
        not recurse into it. Import failures are reported as failing tests.
        """
        # Sort so discovery (and therefore run) order is deterministic rather
        # than depending on the order the filesystem lists entries in.
        paths = sorted(os.listdir(start_dir))

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue
                if not self._match_path(path, full_path, pattern):
                    continue
                # if the test file matches, load it
                name = self._get_name_from_path(full_path)
                try:
                    module = self._get_module_from_name(name)
                except:
                    # Deliberately broad: importing arbitrary test code may
                    # raise anything (even SystemExit). Record it as a
                    # failing test and keep discovering.
                    yield _make_failed_import_test(name, self.suiteClass)
                else:
                    mod_file = os.path.abspath(
                        getattr(module, '__file__', full_path))
                    realpath = os.path.splitext(mod_file)[0]
                    fullpath_noext = os.path.splitext(full_path)[0]
                    # The imported module must come from this very file; a
                    # mismatch means another copy on sys.path shadowed it.
                    if realpath.lower() != fullpath_noext.lower():
                        module_dir = os.path.dirname(realpath)
                        mod_name = os.path.splitext(
                            os.path.basename(full_path))[0]
                        expected_dir = os.path.dirname(full_path)
                        msg = (
                            "%r module incorrectly imported from %r. Expected %r. "
                            "Is this module globally installed?")
                        raise ImportError(msg %
                                          (mod_name, module_dir, expected_dir))
                    yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself
                    # matches the filter
                    name = self._get_name_from_path(full_path)
                    try:
                        package = self._get_module_from_name(name)
                    except:
                        # Same policy as for modules above: report a broken
                        # package as a failing test instead of aborting the
                        # whole run. Importing its submodules would re-run
                        # the failing package import, so do not recurse.
                        yield _make_failed_import_test(name, self.suiteClass)
                        continue
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(
                        package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    try:
                        yield load_tests(self, tests, pattern)
                    except Exception as e:
                        yield _make_failed_load_tests(package.__name__, e,
                                                      self.suiteClass)
309
# Ready-made loader with the default configuration, shared module-wide.
defaultTestLoader = TestLoader()


def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a new TestLoader using *prefix* and *sortUsing*.

    *suiteClass* replaces the loader's default suite class only when a
    truthy value is supplied.
    """
    configured = TestLoader()
    configured.testMethodPrefix = prefix
    configured.sortTestMethodsUsing = sortUsing
    if suiteClass:
        configured.suiteClass = suiteClass
    return configured
320
321
def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp_):
    """Module-level shortcut for ``TestLoader.getTestCaseNames``.

    Builds a throwaway loader from *prefix* and *sortUsing* and returns the
    sorted names of the matching methods on *testCaseClass*.
    """
    loader = _makeLoader(prefix, sortUsing)
    return loader.getTestCaseNames(testCaseClass)
324
325
def makeSuite(testCaseClass, prefix='test', sortUsing=cmp_,
              suiteClass=suite.TestSuite):
    """Module-level shortcut for ``TestLoader.loadTestsFromTestCase``.

    Collects the methods of *testCaseClass* whose names start with *prefix*
    (falling back to ``runTest``), ordered with *sortUsing*, into a suite of
    type *suiteClass*.
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromTestCase(testCaseClass)
332
333
def findTestCases(module, prefix='test', sortUsing=cmp_,
                  suiteClass=suite.TestSuite):
    """Module-level shortcut for ``TestLoader.loadTestsFromModule``.

    Loads every TestCase subclass found in *module* with a loader configured
    from *prefix*, *sortUsing* and *suiteClass*; a module-level
    ``load_tests`` hook is honoured.
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromModule(module)
340