Browse Source

Fixed #32489 -- Added iter_test_cases() to iterate over a TestSuite.

This also makes partition_suite_by_type(), partition_suite_by_case(),
filter_tests_by_tags(), and DiscoverRunner._get_databases() to use
iter_test_cases().
Chris Jerdonek 4 years ago
parent
commit
22c9af0eae
3 changed files with 128 additions and 48 deletions
  1. 35 48
      django/test/runner.py
  2. 12 0
      django/test/utils.py
  3. 81 0
      tests/test_runner/tests.py

+ 35 - 48
django/test/runner.py

@@ -16,9 +16,9 @@ from django.core.management import call_command
 from django.db import connections
 from django.test import SimpleTestCase, TestCase
 from django.test.utils import (
-    NullTimeKeeper, TimeKeeper, setup_databases as _setup_databases,
-    setup_test_environment, teardown_databases as _teardown_databases,
-    teardown_test_environment,
+    NullTimeKeeper, TimeKeeper, iter_test_cases,
+    setup_databases as _setup_databases, setup_test_environment,
+    teardown_databases as _teardown_databases, teardown_test_environment,
 )
 from django.utils.datastructures import OrderedSet
 
@@ -683,19 +683,16 @@ class DiscoverRunner:
 
     def _get_databases(self, suite):
         databases = {}
-        for test in suite:
-            if isinstance(test, unittest.TestCase):
-                test_databases = getattr(test, 'databases', None)
-                if test_databases == '__all__':
-                    test_databases = connections
-                if test_databases:
-                    serialized_rollback = getattr(test, 'serialized_rollback', False)
-                    databases.update(
-                        (alias, serialized_rollback or databases.get(alias, False))
-                        for alias in test_databases
-                    )
-            else:
-                databases.update(self._get_databases(test))
+        for test in iter_test_cases(suite):
+            test_databases = getattr(test, 'databases', None)
+            if test_databases == '__all__':
+                test_databases = connections
+            if test_databases:
+                serialized_rollback = getattr(test, 'serialized_rollback', False)
+                databases.update(
+                    (alias, serialized_rollback or databases.get(alias, False))
+                    for alias in test_databases
+                )
         return databases
 
     def get_databases(self, suite):
@@ -800,49 +797,39 @@ def partition_suite_by_type(suite, classes, bins, reverse=False):
     Tests of type classes[i] are added to bins[i],
     tests with no match found in classes are place in bins[-1]
     """
-    suite_class = type(suite)
-    if reverse:
-        suite = reversed(tuple(suite))
-    for test in suite:
-        if isinstance(test, suite_class):
-            partition_suite_by_type(test, classes, bins, reverse=reverse)
+    for test in iter_test_cases(suite, reverse=reverse):
+        for i in range(len(classes)):
+            if isinstance(test, classes[i]):
+                bins[i].add(test)
+                break
         else:
-            for i in range(len(classes)):
-                if isinstance(test, classes[i]):
-                    bins[i].add(test)
-                    break
-            else:
-                bins[-1].add(test)
+            bins[-1].add(test)
 
 
 def partition_suite_by_case(suite):
     """Partition a test suite by test case, preserving the order of tests."""
-    groups = []
+    subsuites = []
     suite_class = type(suite)
-    for test_type, test_group in itertools.groupby(suite, type):
-        if issubclass(test_type, unittest.TestCase):
-            groups.append(suite_class(test_group))
-        else:
-            for item in test_group:
-                groups.extend(partition_suite_by_case(item))
-    return groups
+    tests = iter_test_cases(suite)
+    for test_type, test_group in itertools.groupby(tests, type):
+        subsuite = suite_class(test_group)
+        subsuites.append(subsuite)
+
+    return subsuites
 
 
 def filter_tests_by_tags(suite, tags, exclude_tags):
     suite_class = type(suite)
     filtered_suite = suite_class()
 
-    for test in suite:
-        if isinstance(test, suite_class):
-            filtered_suite.addTests(filter_tests_by_tags(test, tags, exclude_tags))
-        else:
-            test_tags = set(getattr(test, 'tags', set()))
-            test_fn_name = getattr(test, '_testMethodName', str(test))
-            test_fn = getattr(test, test_fn_name, test)
-            test_fn_tags = set(getattr(test_fn, 'tags', set()))
-            all_tags = test_tags.union(test_fn_tags)
-            matched_tags = all_tags.intersection(tags)
-            if (matched_tags or not tags) and not all_tags.intersection(exclude_tags):
-                filtered_suite.addTest(test)
+    for test in iter_test_cases(suite):
+        test_tags = set(getattr(test, 'tags', set()))
+        test_fn_name = getattr(test, '_testMethodName', str(test))
+        test_fn = getattr(test, test_fn_name, test)
+        test_fn_tags = set(getattr(test_fn, 'tags', set()))
+        all_tags = test_tags.union(test_fn_tags)
+        matched_tags = all_tags.intersection(tags)
+        if (matched_tags or not tags) and not all_tags.intersection(exclude_tags):
+            filtered_suite.addTest(test)
 
     return filtered_suite

+ 12 - 0
django/test/utils.py

@@ -235,6 +235,18 @@ def setup_databases(
     return old_names
 
 
+def iter_test_cases(suite, reverse=False):
+    """Return an iterator over a test suite's unittest.TestCase objects."""
+    if reverse:
+        suite = reversed(tuple(suite))
+    for test in suite:
+        if isinstance(test, TestCase):
+            yield test
+        else:
+            # Otherwise, assume it is a test suite.
+            yield from iter_test_cases(test, reverse=reverse)
+
+
 def dependency_ordered(test_databases, dependencies):
     """
     Reorder test_databases into an order that honors the dependencies

+ 81 - 0
tests/test_runner/tests.py

@@ -18,12 +18,93 @@ from django.test.runner import DiscoverRunner
 from django.test.testcases import connections_support_transactions
 from django.test.utils import (
     captured_stderr, dependency_ordered, get_unique_databases_and_mirrors,
+    iter_test_cases,
 )
 from django.utils.deprecation import RemovedInDjango50Warning
 
 from .models import B, Person, Through
 
 
+class MySuite:
+    def __init__(self):
+        self.tests = []
+
+    def addTest(self, test):
+        self.tests.append(test)
+
+    def __iter__(self):
+        yield from self.tests
+
+
+class IterTestCasesTests(unittest.TestCase):
+    def make_test_suite(self, suite=None, suite_class=None):
+        if suite_class is None:
+            suite_class = unittest.TestSuite
+        if suite is None:
+            suite = suite_class()
+
+        class Tests1(unittest.TestCase):
+            def test1(self):
+                pass
+
+            def test2(self):
+                pass
+
+        class Tests2(unittest.TestCase):
+            def test1(self):
+                pass
+
+            def test2(self):
+                pass
+
+        loader = unittest.defaultTestLoader
+        for test_cls in (Tests1, Tests2):
+            tests = loader.loadTestsFromTestCase(test_cls)
+            subsuite = suite_class()
+            # Only use addTest() to simplify testing a custom TestSuite.
+            for test in tests:
+                subsuite.addTest(test)
+            suite.addTest(subsuite)
+
+        return suite
+
+    def assertTestNames(self, tests, expected):
+        # Each test.id() has a form like the following:
+        # "test_runner.tests.IterTestCasesTests.test_iter_test_cases.<locals>.Tests1.test1".
+        # It suffices to check only the last two parts.
+        names = ['.'.join(test.id().split('.')[-2:]) for test in tests]
+        self.assertEqual(names, expected)
+
+    def test_basic(self):
+        suite = self.make_test_suite()
+        tests = iter_test_cases(suite)
+        self.assertTestNames(tests, expected=[
+            'Tests1.test1', 'Tests1.test2', 'Tests2.test1', 'Tests2.test2',
+        ])
+
+    def test_reverse(self):
+        suite = self.make_test_suite()
+        tests = iter_test_cases(suite, reverse=True)
+        self.assertTestNames(tests, expected=[
+            'Tests2.test2', 'Tests2.test1', 'Tests1.test2', 'Tests1.test1',
+        ])
+
+    def test_custom_test_suite_class(self):
+        suite = self.make_test_suite(suite_class=MySuite)
+        tests = iter_test_cases(suite)
+        self.assertTestNames(tests, expected=[
+            'Tests1.test1', 'Tests1.test2', 'Tests2.test1', 'Tests2.test2',
+        ])
+
+    def test_mixed_test_suite_classes(self):
+        suite = self.make_test_suite(suite=MySuite())
+        child_suite = list(suite)[0]
+        self.assertNotIsInstance(child_suite, MySuite)
+        tests = list(iter_test_cases(suite))
+        self.assertEqual(len(tests), 4)
+        self.assertNotIsInstance(tests[0], unittest.TestSuite)
+
+
 class DependencyOrderingTests(unittest.TestCase):
 
     def test_simple_dependencies(self):