"""Loading unittests."""

import functools
import os
import re
import sys
import traceback
import types
import unittest

from fnmatch import fnmatch

from unittest2 import case, suite, cmp_

try:
    from os.path import relpath
except ImportError:
    from unittest2.compatibility import relpath

# NOTE(review): module-level marker; in the stdlib unittest package this flag
# marks frames to be trimmed from failure tracebacks -- confirm unittest2
# honours it the same way.
__unittest = True

# what about .pyc or .pyo (etc)
# we would need to avoid loading the same tests multiple times
# from '.py', '.pyc' *and* '.pyo'
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)


def _make_failed_import_test(name, suiteClass):
    """Return a suite with one test that reports a failed import of *name*.

    Must be called from inside an ``except`` block: the traceback of the
    exception currently being handled is embedded in the failure message.
    """
    message = 'Failed to import test module: %s\n%s' % (
        name, traceback.format_exc())
    return _make_failed_test('ModuleImportFailure', name, ImportError(message),
                             suiteClass)


def _make_failed_load_tests(name, exception, suiteClass):
    """Return a suite with one test that re-raises *exception*.

    Used when a module's or package's ``load_tests`` hook itself raised.
    """
    return _make_failed_test('LoadTestsFailure', name, exception, suiteClass)


def _make_failed_test(classname, methodname, exception, suiteClass):
    """Build a one-test suite whose only test raises *exception*.

    A throwaway ``case.TestCase`` subclass named *classname* is created with
    a single method called *methodname*, so the failure is reported under
    that name by the runner.
    """
    def testFailure(self):
        raise exception
    attrs = {methodname: testFailure}
    TestClass = type(classname, (case.TestCase,), attrs)
    return suiteClass((TestClass(methodname),))


class TestLoader(unittest.TestLoader):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """

    def __init__(self):
        # Let the stdlib base class set up whatever state it keeps (newer
        # Python versions initialise extra attributes in __init__).
        super(TestLoader, self).__init__()
        self.testMethodPrefix = 'test'
        self.sortTestMethodsUsing = cmp_
        self.suiteClass = suite.TestSuite
        self._top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass.

        Falls back to a single ``runTest`` test when no method matches
        ``testMethodPrefix`` but the class defines ``runTest``.

        Raises TypeError if *testCaseClass* is a TestSuite subclass.
        """
        # Accept both the unittest2 and the stdlib suite base so the guard
        # works whichever one the class derives from.
        if issubclass(testCaseClass, (suite.TestSuite, unittest.TestSuite)):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        return self.suiteClass(
            [testCaseClass(name) for name in testCaseNames])

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module.

        Every ``unittest.TestCase`` subclass found via ``dir(module)`` is
        loaded.  If *use_load_tests* is true and the module defines
        ``load_tests``, its result ``load_tests(self, tests, None)`` is
        returned instead; an exception raised by that hook is turned into a
        ``LoadTestsFailure`` test rather than propagated.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, unittest.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if use_load_tests and load_tests is not None:
            try:
                return load_tests(self, tests, None)
            except Exception as e:
                return _make_failed_load_tests(module.__name__, e,
                                               self.suiteClass)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Raises ImportError if no prefix of *name* can be imported (when
        *module* is None), AttributeError if a component cannot be resolved,
        and TypeError if the resolved object cannot be turned into tests.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable prefix of the dotted name.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__ returns the top-level package, so attribute
            # resolution continues from the second component.
            parts = parts[1:]

        parent = None
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, unittest.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, (types.MethodType, types.FunctionType)) and
              isinstance(parent, type) and
              issubclass(parent, (case.TestCase, unittest.TestCase))):
            # A test method: instantiate its class for just that method.
            return self.suiteClass([parent(obj.__name__)])
        elif isinstance(obj, unittest.TestSuite):
            return obj
        elif callable(obj):
            test = obj()
            if isinstance(test, unittest.TestSuite):
                return test
            elif isinstance(test, unittest.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass.

        A name qualifies if it starts with ``testMethodPrefix`` and the
        attribute is callable.  Sorting uses ``sortTestMethodsUsing`` (an
        old-style ``cmp`` function); a falsy value leaves ``dir()`` order.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return (attrname.startswith(prefix) and
                    callable(getattr(testCaseClass, attrname)))
        testFnNames = list(filter(isTestMethod, dir(testCaseClass)))
        if self.sortTestMethodsUsing:
            testFnNames.sort(
                key=functools.cmp_to_key(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        *start_dir* may also be a dotted module name, in which case discovery
        starts from that module's directory.

        Side effect: the absolute top-level directory is inserted at the
        front of ``sys.path`` if it is not already present.

        Raises ImportError if the start directory is not importable.
        """
        set_implicit_top = False
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a
            # package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            set_implicit_top = True
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(top_level_dir)

        # Remember whether this call added the entry, so that it is only ever
        # removed again if it was ours -- never strip a pre-existing entry.
        added_to_path = top_level_dir not in sys.path
        if added_to_path:
            # all test modules must be importable from the top level directory
            # should we *unconditionally* put the start directory in first
            # in sys.path to minimise likelihood of conflicts between installed
            # modules and development versions?
            sys.path.insert(0, top_level_dir)
        self._top_level_dir = top_level_dir

        is_not_importable = False
        if os.path.isdir(os.path.abspath(start_dir)):
            start_dir = os.path.abspath(start_dir)
            if start_dir != top_level_dir:
                is_not_importable = not os.path.isfile(
                    os.path.join(start_dir, '__init__.py'))
        else:
            # support for discovery from dotted module names
            try:
                __import__(start_dir)
            except ImportError:
                is_not_importable = True
            else:
                the_module = sys.modules[start_dir]
                top_part = start_dir.split('.')[0]
                start_dir = os.path.abspath(
                    os.path.dirname(the_module.__file__))
                if set_implicit_top:
                    self._top_level_dir = \
                        self._get_directory_containing_module(top_part)
                    # The implicit top was derived from the dotted name and
                    # is meaningless as a directory; drop it if we added it.
                    if added_to_path:
                        sys.path.remove(top_level_dir)

        if is_not_importable:
            raise ImportError(
                'Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_directory_containing_module(self, module_name):
        """Return the directory from which *module_name* is importable.

        For a package (``__init__.py*``) that is the package's parent
        directory; for a plain module it is the module's own directory.
        """
        full_path = os.path.abspath(sys.modules[module_name].__file__)
        if os.path.basename(full_path).lower().startswith('__init__.py'):
            return os.path.dirname(os.path.dirname(full_path))
        # A plain module rather than a package: search the directory it is
        # in.
        return os.path.dirname(full_path)

    def _get_name_from_path(self, path):
        """Convert a file or package path into a dotted module name.

        The path is made relative to the top-level directory, its extension
        is dropped and path separators become dots.

        Raises AssertionError if *path* lies outside the top-level directory
        (raised explicitly so the check also holds under ``python -O``).
        """
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = relpath(path, self._top_level_dir)
        if os.path.isabs(_relpath) or _relpath.startswith('..'):
            raise AssertionError("Path must be within the project")

        return _relpath.replace(os.path.sep, '.')

    def _get_module_from_name(self, name):
        """Import the dotted module *name* and return the module object."""
        __import__(name)
        return sys.modules[name]

    def _match_path(self, path, full_path, pattern):
        # override this method to use alternative matching strategy
        return fnmatch(path, pattern)

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads.

        Files must be valid module names and satisfy ``_match_path``.
        Directories are only entered if they contain ``__init__.py``.  Import
        failures are reported as ``ModuleImportFailure`` tests instead of
        aborting discovery; ``KeyboardInterrupt`` still propagates.
        """
        # Sort so the discovery order does not depend on the filesystem.
        paths = sorted(os.listdir(start_dir))

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue
                if not self._match_path(path, full_path, pattern):
                    continue
                # if the test file matches, load it
                name = self._get_name_from_path(full_path)
                try:
                    module = self._get_module_from_name(name)
                except KeyboardInterrupt:
                    raise
                except BaseException:
                    # Anything else raised while importing (including a
                    # SystemExit from module code) becomes a failing test.
                    yield _make_failed_import_test(name, self.suiteClass)
                else:
                    # Guard against a same-named module being imported from
                    # somewhere else on sys.path (e.g. a global install).
                    mod_file = os.path.abspath(
                        getattr(module, '__file__', full_path))
                    realpath = os.path.splitext(mod_file)[0]
                    fullpath_noext = os.path.splitext(full_path)[0]
                    if realpath.lower() != fullpath_noext.lower():
                        module_dir = os.path.dirname(realpath)
                        mod_name = os.path.splitext(
                            os.path.basename(full_path))[0]
                        expected_dir = os.path.dirname(full_path)
                        msg = ("%r module incorrectly imported from %r. "
                               "Expected %r. "
                               "Is this module globally installed?")
                        raise ImportError(
                            msg % (mod_name, module_dir, expected_dir))
                    yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself
                    # matches the filter
                    name = self._get_name_from_path(full_path)
                    try:
                        package = self._get_module_from_name(name)
                    except KeyboardInterrupt:
                        raise
                    except BaseException:
                        # Report a broken package like a broken module; its
                        # contents cannot be imported, so do not recurse.
                        yield _make_failed_import_test(name, self.suiteClass)
                        continue
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(
                        package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    try:
                        yield load_tests(self, tests, pattern)
                    except Exception as e:
                        yield _make_failed_load_tests(package.__name__, e,
                                                      self.suiteClass)


# Shared module-level loader instance.
defaultTestLoader = TestLoader()


def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a new TestLoader with the given method prefix and sort function.

    *suiteClass* replaces the loader's suite class only when it is truthy.
    """
    loader = TestLoader()
    loader.sortTestMethodsUsing = sortUsing
    loader.testMethodPrefix = prefix
    if suiteClass:
        loader.suiteClass = suiteClass
    return loader


def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp_):
    """Return the sorted test method names of *testCaseClass*."""
    return _makeLoader(prefix, sortUsing).getTestCaseNames(testCaseClass)


def makeSuite(testCaseClass, prefix='test', sortUsing=cmp_,
              suiteClass=suite.TestSuite):
    """Return a suite of all tests in *testCaseClass*."""
    return _makeLoader(
        prefix,
        sortUsing,
        suiteClass).loadTestsFromTestCase(testCaseClass)


def findTestCases(module, prefix='test', sortUsing=cmp_,
                  suiteClass=suite.TestSuite):
    """Return a suite of all TestCase subclasses' tests found in *module*."""
    return _makeLoader(
        prefix,
        sortUsing,
        suiteClass).loadTestsFromModule(module)