瀏覽代碼

Fixed #35849 -- Made ParallelTestSuite report correct error location.

David Winiecki 5 月之前
父節點
當前提交
661dfdd598
共有 3 個文件被更改,包括 158 次插入9 次删除
  1. 1 0
      AUTHORS
  2. 23 3
      django/test/runner.py
  3. 134 6
      tests/test_runner/test_parallel.py

+ 1 - 0
AUTHORS

@@ -282,6 +282,7 @@ answer newbie questions, and generally made Django that much better:
     David Sanders <dsanders11@ucsbalum.com>
     David Schein
     David Tulig <david.tulig@gmail.com>
+    David Winiecki <david.winiecki@gmail.com>
     David Winterbottom <david.winterbottom@gmail.com>
     David Wobrock <david.wobrock@gmail.com>
     Davide Ceretti <dav.ceretti@gmail.com>

+ 23 - 3
django/test/runner.py

@@ -12,6 +12,7 @@ import random
 import sys
 import textwrap
 import unittest
+import unittest.suite
 from collections import defaultdict
 from contextlib import contextmanager
 from importlib import import_module
@@ -292,7 +293,15 @@ failure and get a correct traceback.
 
     def addError(self, test, err):
         self.check_picklable(test, err)
-        self.events.append(("addError", self.test_index, err))
+
+        event_occurred_before_first_test = self.test_index == -1
+        if event_occurred_before_first_test and isinstance(
+            test, unittest.suite._ErrorHolder
+        ):
+            self.events.append(("addError", self.test_index, test.id(), err))
+        else:
+            self.events.append(("addError", self.test_index, err))
+
         super().addError(test, err)
 
     def addFailure(self, test, err):
@@ -558,8 +567,19 @@ class ParallelTestSuite(unittest.TestSuite):
         handler = getattr(result, event_name, None)
         if handler is None:
             return
-        test = tests[event[1]]
-        args = event[2:]
+        test_index = event[1]
+        event_occurred_before_first_test = test_index == -1
+        if (
+            event_name == "addError"
+            and event_occurred_before_first_test
+            and len(event) >= 4
+        ):
+            test_id = event[2]
+            test = unittest.suite._ErrorHolder(test_id)
+            args = event[3:]
+        else:
+            test = tests[test_index]
+            args = event[2:]
         handler(test, *args)
 
     def __iter__(self):

+ 134 - 6
tests/test_runner/test_parallel.py

@@ -1,9 +1,12 @@
 import pickle
 import sys
 import unittest
+from unittest.case import TestCase
+from unittest.result import TestResult
+from unittest.suite import TestSuite, _ErrorHolder
 
 from django.test import SimpleTestCase
-from django.test.runner import RemoteTestResult
+from django.test.runner import ParallelTestSuite, RemoteTestResult
 from django.utils.version import PY311, PY312
 
 try:
@@ -59,6 +62,18 @@ class SampleFailingSubtest(SimpleTestCase):
             self.fail("expected failure")
 
 
+class SampleErrorTest(SimpleTestCase):
+    @classmethod
+    def setUpClass(cls):
+        raise ValueError("woops")
+        super().setUpClass()
+
+    # This method name doesn't begin with "test" to prevent test discovery
+    # from seeing it.
+    def dummy_test(self):
+        raise AssertionError("SampleErrorTest.dummy_test() was called")
+
+
 class RemoteTestResultTest(SimpleTestCase):
     def _test_error_exc_info(self):
         try:
@@ -72,29 +87,70 @@ class RemoteTestResultTest(SimpleTestCase):
 
     def test_was_successful_one_success(self):
         result = RemoteTestResult()
-        result.addSuccess(None)
+        test = None
+        result.startTest(test)
+        try:
+            result.addSuccess(test)
+        finally:
+            result.stopTest(test)
         self.assertIs(result.wasSuccessful(), True)
 
     def test_was_successful_one_expected_failure(self):
         result = RemoteTestResult()
-        result.addExpectedFailure(None, self._test_error_exc_info())
+        test = None
+        result.startTest(test)
+        try:
+            result.addExpectedFailure(test, self._test_error_exc_info())
+        finally:
+            result.stopTest(test)
         self.assertIs(result.wasSuccessful(), True)
 
     def test_was_successful_one_skip(self):
         result = RemoteTestResult()
-        result.addSkip(None, "Skipped")
+        test = None
+        result.startTest(test)
+        try:
+            result.addSkip(test, "Skipped")
+        finally:
+            result.stopTest(test)
         self.assertIs(result.wasSuccessful(), True)
 
     @unittest.skipUnless(tblib is not None, "requires tblib to be installed")
     def test_was_successful_one_error(self):
         result = RemoteTestResult()
-        result.addError(None, self._test_error_exc_info())
+        test = None
+        result.startTest(test)
+        try:
+            result.addError(test, self._test_error_exc_info())
+        finally:
+            result.stopTest(test)
         self.assertIs(result.wasSuccessful(), False)
 
     @unittest.skipUnless(tblib is not None, "requires tblib to be installed")
     def test_was_successful_one_failure(self):
         result = RemoteTestResult()
-        result.addFailure(None, self._test_error_exc_info())
+        test = None
+        result.startTest(test)
+        try:
+            result.addFailure(test, self._test_error_exc_info())
+        finally:
+            result.stopTest(test)
+        self.assertIs(result.wasSuccessful(), False)
+
+    @unittest.skipUnless(tblib is not None, "requires tblib to be installed")
+    def test_add_error_before_first_test(self):
+        result = RemoteTestResult()
+        test_id = "test_foo (tests.test_foo.FooTest.test_foo)"
+        test = _ErrorHolder(test_id)
+        # Call addError() without a call to startTest().
+        result.addError(test, self._test_error_exc_info())
+
+        (event,) = result.events
+        self.assertEqual(event[0], "addError")
+        self.assertEqual(event[1], -1)
+        self.assertEqual(event[2], test_id)
+        (error_type, _, _) = event[3]
+        self.assertEqual(error_type, ValueError)
         self.assertIs(result.wasSuccessful(), False)
 
     def test_picklable(self):
@@ -161,3 +217,75 @@ class RemoteTestResultTest(SimpleTestCase):
         result = RemoteTestResult()
         result.addDuration(None, 2.3)
         self.assertEqual(result.collectedDurations, [("None", 2.3)])
+
+
+class ParallelTestSuiteTest(SimpleTestCase):
+    def test_handle_add_error_before_first_test(self):
+        dummy_subsuites = []
+        pts = ParallelTestSuite(dummy_subsuites, processes=2)
+        result = TestResult()
+        remote_result = RemoteTestResult()
+        test = SampleErrorTest(methodName="dummy_test")
+        suite = TestSuite([test])
+        suite.run(remote_result)
+        for event in remote_result.events:
+            pts.handle_event(result, tests=list(suite), event=event)
+
+        self.assertEqual(len(result.errors), 1)
+        actual_test, tb_and_details_str = result.errors[0]
+        self.assertIsInstance(actual_test, _ErrorHolder)
+        self.assertEqual(
+            actual_test.id(), "setUpClass (test_runner.test_parallel.SampleErrorTest)"
+        )
+        self.assertIn("Traceback (most recent call last):", tb_and_details_str)
+        self.assertIn("ValueError: woops", tb_and_details_str)
+
+    def test_handle_add_error_during_test(self):
+        dummy_subsuites = []
+        pts = ParallelTestSuite(dummy_subsuites, processes=2)
+        result = TestResult()
+        test = TestCase()
+        err = _test_error_exc_info()
+        event = ("addError", 0, err)
+        pts.handle_event(result, tests=[test], event=event)
+
+        self.assertEqual(len(result.errors), 1)
+        actual_test, tb_and_details_str = result.errors[0]
+        self.assertIsInstance(actual_test, TestCase)
+        self.assertEqual(actual_test.id(), "unittest.case.TestCase.runTest")
+        self.assertIn("Traceback (most recent call last):", tb_and_details_str)
+        self.assertIn("ValueError: woops", tb_and_details_str)
+
+    def test_handle_add_failure(self):
+        dummy_subsuites = []
+        pts = ParallelTestSuite(dummy_subsuites, processes=2)
+        result = TestResult()
+        test = TestCase()
+        err = _test_error_exc_info()
+        event = ("addFailure", 0, err)
+        pts.handle_event(result, tests=[test], event=event)
+
+        self.assertEqual(len(result.failures), 1)
+        actual_test, tb_and_details_str = result.failures[0]
+        self.assertIsInstance(actual_test, TestCase)
+        self.assertEqual(actual_test.id(), "unittest.case.TestCase.runTest")
+        self.assertIn("Traceback (most recent call last):", tb_and_details_str)
+        self.assertIn("ValueError: woops", tb_and_details_str)
+
+    def test_handle_add_success(self):
+        dummy_subsuites = []
+        pts = ParallelTestSuite(dummy_subsuites, processes=2)
+        result = TestResult()
+        test = TestCase()
+        event = ("addSuccess", 0)
+        pts.handle_event(result, tests=[test], event=event)
+
+        self.assertEqual(len(result.errors), 0)
+        self.assertEqual(len(result.failures), 0)
+
+
+def _test_error_exc_info():
+    try:
+        raise ValueError("woops")
+    except ValueError:
+        return sys.exc_info()