浏览代码

Fixed #29363 -- Added SimpleTestCase.assertWarnsMessage().

Morgan Aubert 7 年之前
父节点
当前提交
704443acac

+ 29 - 14
django/test/testcases.py

@@ -585,10 +585,23 @@ class SimpleTestCase(unittest.TestCase):
         )
 
     @contextmanager
-    def _assert_raises_message_cm(self, expected_exception, expected_message):
-        with self.assertRaises(expected_exception) as cm:
+    def _assert_raises_or_warns_cm(self, func, cm_attr, expected_exception, expected_message):
+        with func(expected_exception) as cm:
             yield cm
-        self.assertIn(expected_message, str(cm.exception))
+        self.assertIn(expected_message, str(getattr(cm, cm_attr)))
+
+    def _assertFooMessage(self, func, cm_attr, expected_exception, expected_message, *args, **kwargs):
+        callable_obj = None
+        if args:
+            callable_obj = args[0]
+            args = args[1:]
+        cm = self._assert_raises_or_warns_cm(func, cm_attr, expected_exception, expected_message)
+        # Assertion used in context manager fashion.
+        if callable_obj is None:
+            return cm
+        # Assertion was passed a callable.
+        with cm:
+            callable_obj(*args, **kwargs)
 
     def assertRaisesMessage(self, expected_exception, expected_message, *args, **kwargs):
         """
@@ -601,18 +614,20 @@ class SimpleTestCase(unittest.TestCase):
             args: Function to be called and extra positional args.
             kwargs: Extra kwargs.
         """
-        callable_obj = None
-        if args:
-            callable_obj = args[0]
-            args = args[1:]
+        return self._assertFooMessage(
+            self.assertRaises, 'exception', expected_exception, expected_message,
+            *args, **kwargs
+        )
 
-        cm = self._assert_raises_message_cm(expected_exception, expected_message)
-        # Assertion used in context manager fashion.
-        if callable_obj is None:
-            return cm
-        # Assertion was passed a callable.
-        with cm:
-            callable_obj(*args, **kwargs)
+    def assertWarnsMessage(self, expected_warning, expected_message, *args, **kwargs):
+        """
+        Same as assertRaisesMessage but for assertWarns() instead of
+        assertRaises().
+        """
+        return self._assertFooMessage(
+            self.assertWarns, 'warning', expected_warning, expected_message,
+            *args, **kwargs
+        )
 
     def assertFieldOutput(self, fieldclass, valid, invalid, field_args=None,
                           field_kwargs=None, empty_value=''):

+ 8 - 4
docs/internals/contributing/writing-code/coding-style.txt

@@ -62,10 +62,14 @@ Python style
 
 * In docstrings, follow the style of existing docstrings and :pep:`257`.
 
-* In tests, use :meth:`~django.test.SimpleTestCase.assertRaisesMessage` instead
-  of :meth:`~unittest.TestCase.assertRaises` so you can check the exception
-  message. Use :meth:`~unittest.TestCase.assertRaisesRegex` only if you need
-  regular expression matching.
+* In tests, use
+  :meth:`~django.test.SimpleTestCase.assertRaisesMessage` and
+  :meth:`~django.test.SimpleTestCase.assertWarnsMessage`
+  instead of :meth:`~unittest.TestCase.assertRaises` and
+  :meth:`~unittest.TestCase.assertWarns` so you can check the
+  exception or warning message. Use :meth:`~unittest.TestCase.assertRaisesRegex`
+  and :meth:`~unittest.TestCase.assertWarnsRegex` only if you need regular
+  expression matching.
 
 * In test docstrings, state the expected behavior that each test demonstrates.
   Don't include preambles such as "Tests that" or "Ensures that".

+ 3 - 0
docs/releases/2.1.txt

@@ -277,6 +277,9 @@ Tests
   dictionary as JSON if ``content_type='application/json'``. You can customize
   the JSON encoder with test client's ``json_encoder`` parameter.
 
+* The new :meth:`.SimpleTestCase.assertWarnsMessage` method is a simpler
+  version of :meth:`~unittest.TestCase.assertWarnsRegex`.
+
 URLs
 ~~~~
 

+ 11 - 0
docs/topics/testing/tools.txt

@@ -692,6 +692,8 @@ A subclass of :class:`unittest.TestCase` that adds this functionality:
 
   * Checking that a callable :meth:`raises a certain exception
     <SimpleTestCase.assertRaisesMessage>`.
+  * Checking that a callable :meth:`triggers a certain warning
+    <SimpleTestCase.assertWarnsMessage>`.
   * Testing form field :meth:`rendering and error treatment
     <SimpleTestCase.assertFieldOutput>`.
   * Testing :meth:`HTML responses for the presence/lack of a given fragment
@@ -1362,6 +1364,15 @@ your test suite.
         with self.assertRaisesMessage(ValueError, 'invalid literal for int()'):
             int('a')
 
+.. method:: SimpleTestCase.assertWarnsMessage(expected_warning, expected_message, callable, *args, **kwargs)
+            SimpleTestCase.assertWarnsMessage(expected_warning, expected_message)
+
+    .. versionadded:: 2.1
+
+    Analogous to :meth:`SimpleTestCase.assertRaisesMessage` but for
+    :meth:`~unittest.TestCase.assertWarnsRegex` instead of
+    :meth:`~unittest.TestCase.assertRaisesRegex`.
+
 .. method:: SimpleTestCase.assertFieldOutput(fieldclass, valid, invalid, field_args=None, field_kwargs=None, empty_value='')
 
     Asserts that a form field behaves correctly with various inputs.

+ 1 - 7
tests/admin_views/test_static_deprecation.py

@@ -1,5 +1,3 @@
-import warnings
-
 from django.contrib.admin.templatetags.admin_static import static
 from django.contrib.staticfiles.storage import staticfiles_storage
 from django.test import SimpleTestCase
@@ -19,12 +17,8 @@ class AdminStaticDeprecationTests(SimpleTestCase):
         old_url = staticfiles_storage.base_url
         staticfiles_storage.base_url = '/test/'
         try:
-            with warnings.catch_warnings(record=True) as recorded:
-                warnings.simplefilter('always')
+            with self.assertWarnsMessage(RemovedInDjango30Warning, msg):
                 url = static('path')
             self.assertEqual(url, '/test/path')
-            self.assertEqual(len(recorded), 1)
-            self.assertIs(recorded[0].category, RemovedInDjango30Warning)
-            self.assertEqual(str(recorded[0].message), msg)
         finally:
             staticfiles_storage.base_url = old_url

+ 8 - 6
tests/backends/postgresql/tests.py

@@ -1,5 +1,4 @@
 import unittest
-import warnings
 from unittest import mock
 
 from django.core.exceptions import ImproperlyConfigured
@@ -24,7 +23,14 @@ class Tests(TestCase):
         self.assertIsNone(nodb_conn.settings_dict['NAME'])
 
         # Now assume the 'postgres' db isn't available
-        with warnings.catch_warnings(record=True) as w:
+        msg = (
+            "Normally Django will use a connection to the 'postgres' database "
+            "to avoid running initialization queries against the production "
+            "database when it's not needed (for example, when running tests). "
+            "Django was unable to create a connection to the 'postgres' "
+            "database and will use the first PostgreSQL database instead."
+        )
+        with self.assertWarnsMessage(RuntimeWarning, msg):
             with mock.patch('django.db.backends.base.base.BaseDatabaseWrapper.connect',
                             side_effect=mocked_connect, autospec=True):
                 with mock.patch.object(
@@ -32,13 +38,9 @@ class Tests(TestCase):
                     'settings_dict',
                     {**connection.settings_dict, 'NAME': 'postgres'},
                 ):
-                    warnings.simplefilter('always', RuntimeWarning)
                     nodb_conn = connection._nodb_connection
         self.assertIsNotNone(nodb_conn.settings_dict['NAME'])
         self.assertEqual(nodb_conn.settings_dict['NAME'], connections['other'].settings_dict['NAME'])
-        # Check a RuntimeWarning has been emitted
-        self.assertEqual(len(w), 1)
-        self.assertEqual(w[0].message.__class__, RuntimeWarning)
 
     def test_database_name_too_long(self):
         from django.db.backends.postgresql.base import DatabaseWrapper

+ 3 - 6
tests/backends/tests.py

@@ -441,13 +441,10 @@ class BackendTestCase(TransactionTestCase):
                 cursor.execute("SELECT 3" + new_connection.features.bare_select_suffix)
                 cursor.execute("SELECT 4" + new_connection.features.bare_select_suffix)
 
-            with warnings.catch_warnings(record=True) as w:
+            msg = "Limit for query logging exceeded, only the last 3 queries will be returned."
+            with self.assertWarnsMessage(UserWarning, msg):
                 self.assertEqual(3, len(new_connection.queries))
-                self.assertEqual(1, len(w))
-                self.assertEqual(
-                    str(w[0].message),
-                    "Limit for query logging exceeded, only the last 3 queries will be returned."
-                )
+
         finally:
             BaseDatabaseWrapper.queries_limit = old_queries_limit
             new_connection.close()

+ 1 - 6
tests/cache/tests.py

@@ -10,7 +10,6 @@ import tempfile
 import threading
 import time
 import unittest
-import warnings
 from unittest import mock
 
 from django.conf import settings
@@ -632,12 +631,8 @@ class BaseCacheTests:
         cache.key_func = func
 
         try:
-            with warnings.catch_warnings(record=True) as w:
-                warnings.simplefilter("always")
+            with self.assertWarnsMessage(CacheKeyWarning, expected_warning):
                 cache.set(key, 'value')
-                self.assertEqual(len(w), 1)
-                self.assertIsInstance(w[0].message, CacheKeyWarning)
-                self.assertEqual(str(w[0].message.args[0]), expected_warning)
         finally:
             cache.key_func = old_func
 

+ 66 - 87
tests/deprecation/tests.py

@@ -23,107 +23,94 @@ class RenameMethodsTests(SimpleTestCase):
         Ensure a warning is raised upon class definition to suggest renaming
         the faulty method.
         """
-        with warnings.catch_warnings(record=True) as recorded:
-            warnings.simplefilter('always')
-
+        msg = '`Manager.old` method should be renamed `new`.'
+        with self.assertWarnsMessage(DeprecationWarning, msg):
             class Manager(metaclass=RenameManagerMethods):
                 def old(self):
                     pass
-            self.assertEqual(len(recorded), 1)
-            msg = str(recorded[0].message)
-            self.assertEqual(msg, '`Manager.old` method should be renamed `new`.')
 
     def test_get_new_defined(self):
         """
         Ensure `old` complains and not `new` when only `new` is defined.
         """
-        with warnings.catch_warnings(record=True) as recorded:
-            warnings.simplefilter('ignore')
+        class Manager(metaclass=RenameManagerMethods):
+            def new(self):
+                pass
+        manager = Manager()
 
-            class Manager(metaclass=RenameManagerMethods):
-                def new(self):
-                    pass
+        with warnings.catch_warnings(record=True) as recorded:
             warnings.simplefilter('always')
-            manager = Manager()
             manager.new()
-            self.assertEqual(len(recorded), 0)
+        self.assertEqual(len(recorded), 0)
+
+        msg = '`Manager.old` is deprecated, use `new` instead.'
+        with self.assertWarnsMessage(DeprecationWarning, msg):
             manager.old()
-            self.assertEqual(len(recorded), 1)
-            msg = str(recorded.pop().message)
-            self.assertEqual(msg, '`Manager.old` is deprecated, use `new` instead.')
 
     def test_get_old_defined(self):
         """
         Ensure `old` complains when only `old` is defined.
         """
-        with warnings.catch_warnings(record=True) as recorded:
-            warnings.simplefilter('ignore')
+        class Manager(metaclass=RenameManagerMethods):
+            def old(self):
+                pass
+        manager = Manager()
 
-            class Manager(metaclass=RenameManagerMethods):
-                def old(self):
-                    pass
+        with warnings.catch_warnings(record=True) as recorded:
             warnings.simplefilter('always')
-            manager = Manager()
             manager.new()
-            self.assertEqual(len(recorded), 0)
+        self.assertEqual(len(recorded), 0)
+
+        msg = '`Manager.old` is deprecated, use `new` instead.'
+        with self.assertWarnsMessage(DeprecationWarning, msg):
             manager.old()
-            self.assertEqual(len(recorded), 1)
-            msg = str(recorded.pop().message)
-            self.assertEqual(msg, '`Manager.old` is deprecated, use `new` instead.')
 
     def test_deprecated_subclass_renamed(self):
         """
         Ensure the correct warnings are raised when a class that didn't rename
         `old` subclass one that did.
         """
-        with warnings.catch_warnings(record=True) as recorded:
-            warnings.simplefilter('ignore')
+        class Renamed(metaclass=RenameManagerMethods):
+            def new(self):
+                pass
 
-            class Renamed(metaclass=RenameManagerMethods):
-                def new(self):
-                    pass
+        class Deprecated(Renamed):
+            def old(self):
+                super().old()
 
-            class Deprecated(Renamed):
-                def old(self):
-                    super().old()
-            warnings.simplefilter('always')
-            deprecated = Deprecated()
+        deprecated = Deprecated()
+
+        msg = '`Renamed.old` is deprecated, use `new` instead.'
+        with self.assertWarnsMessage(DeprecationWarning, msg):
             deprecated.new()
-            self.assertEqual(len(recorded), 1)
-            msg = str(recorded.pop().message)
-            self.assertEqual(msg, '`Renamed.old` is deprecated, use `new` instead.')
-            recorded[:] = []
+
+        msg = '`Deprecated.old` is deprecated, use `new` instead.'
+        with self.assertWarnsMessage(DeprecationWarning, msg):
             deprecated.old()
-            self.assertEqual(len(recorded), 2)
-            msgs = [str(warning.message) for warning in recorded]
-            self.assertEqual(msgs, [
-                '`Deprecated.old` is deprecated, use `new` instead.',
-                '`Renamed.old` is deprecated, use `new` instead.',
-            ])
 
     def test_renamed_subclass_deprecated(self):
         """
         Ensure the correct warnings are raised when a class that renamed
         `old` subclass one that didn't.
         """
-        with warnings.catch_warnings(record=True) as recorded:
-            warnings.simplefilter('ignore')
+        class Deprecated(metaclass=RenameManagerMethods):
+            def old(self):
+                pass
 
-            class Deprecated(metaclass=RenameManagerMethods):
-                def old(self):
-                    pass
+        class Renamed(Deprecated):
+            def new(self):
+                super().new()
 
-            class Renamed(Deprecated):
-                def new(self):
-                    super().new()
+        renamed = Renamed()
+
+        with warnings.catch_warnings(record=True) as recorded:
             warnings.simplefilter('always')
-            renamed = Renamed()
             renamed.new()
-            self.assertEqual(len(recorded), 0)
+        self.assertEqual(len(recorded), 0)
+
+        msg = '`Renamed.old` is deprecated, use `new` instead.'
+        with self.assertWarnsMessage(DeprecationWarning, msg):
             renamed.old()
-            self.assertEqual(len(recorded), 1)
-            msg = str(recorded.pop().message)
-            self.assertEqual(msg, '`Renamed.old` is deprecated, use `new` instead.')
 
     def test_deprecated_subclass_renamed_and_mixins(self):
         """
@@ -131,36 +118,30 @@ class RenameMethodsTests(SimpleTestCase):
         class that renamed `old` and mixins that may or may not have renamed
         `new`.
         """
-        with warnings.catch_warnings(record=True) as recorded:
-            warnings.simplefilter('ignore')
+        class Renamed(metaclass=RenameManagerMethods):
+            def new(self):
+                pass
 
-            class Renamed(metaclass=RenameManagerMethods):
-                def new(self):
-                    pass
+        class RenamedMixin:
+            def new(self):
+                super().new()
 
-            class RenamedMixin:
-                def new(self):
-                    super().new()
+        class DeprecatedMixin:
+            def old(self):
+                super().old()
 
-            class DeprecatedMixin:
-                def old(self):
-                    super().old()
+        class Deprecated(DeprecatedMixin, RenamedMixin, Renamed):
+            pass
 
-            class Deprecated(DeprecatedMixin, RenamedMixin, Renamed):
-                pass
-            warnings.simplefilter('always')
-            deprecated = Deprecated()
+        deprecated = Deprecated()
+
+        msg = '`RenamedMixin.old` is deprecated, use `new` instead.'
+        with self.assertWarnsMessage(DeprecationWarning, msg):
             deprecated.new()
-            self.assertEqual(len(recorded), 1)
-            msg = str(recorded.pop().message)
-            self.assertEqual(msg, '`RenamedMixin.old` is deprecated, use `new` instead.')
+
+        msg = '`DeprecatedMixin.old` is deprecated, use `new` instead.'
+        with self.assertWarnsMessage(DeprecationWarning, msg):
             deprecated.old()
-            self.assertEqual(len(recorded), 2)
-            msgs = [str(warning.message) for warning in recorded]
-            self.assertEqual(msgs, [
-                '`DeprecatedMixin.old` is deprecated, use `new` instead.',
-                '`RenamedMixin.old` is deprecated, use `new` instead.',
-            ])
 
 
 class DeprecationInstanceCheckTest(SimpleTestCase):
@@ -170,7 +151,5 @@ class DeprecationInstanceCheckTest(SimpleTestCase):
             deprecation_warning = RemovedInNextVersionWarning
 
         msg = '`Manager` is deprecated, use `fake.path.Foo` instead.'
-        with warnings.catch_warnings():
-            warnings.simplefilter('error', category=RemovedInNextVersionWarning)
-            with self.assertRaisesMessage(RemovedInNextVersionWarning, msg):
-                isinstance(object, Manager)
+        with self.assertWarnsMessage(RemovedInNextVersionWarning, msg):
+            isinstance(object, Manager)

+ 2 - 9
tests/fixtures/tests.py

@@ -495,16 +495,9 @@ class FixtureLoadingTests(DumpDataAssertMixin, TestCase):
         parent.
         """
         ProxySpy.objects.create(name='Paul')
-
-        with warnings.catch_warnings(record=True) as warning_list:
-            warnings.simplefilter('always')
+        msg = "fixtures.ProxySpy is a proxy model and won't be serialized."
+        with self.assertWarnsMessage(ProxyModelWarning, msg):
             self._dumpdata_assert(['fixtures.ProxySpy'], '[]')
-        warning = warning_list.pop()
-        self.assertEqual(warning.category, ProxyModelWarning)
-        self.assertEqual(
-            str(warning.message),
-            "fixtures.ProxySpy is a proxy model and won't be serialized."
-        )
 
     def test_dumpdata_proxy_with_concrete(self):
         """

+ 8 - 30
tests/fixtures_regress/tests.py

@@ -2,7 +2,6 @@
 import json
 import os
 import re
-import warnings
 from io import StringIO
 
 from django.core import management, serializers
@@ -209,19 +208,13 @@ class TestFixtures(TestCase):
         using explicit filename.
         Test for ticket #18213 -- warning conditions are caught correctly
         """
-        with warnings.catch_warnings(record=True) as warning_list:
-            warnings.simplefilter("always")
+        msg = "No fixture data found for 'bad_fixture2'. (File format may be invalid.)"
+        with self.assertWarnsMessage(RuntimeWarning, msg):
             management.call_command(
                 'loaddata',
                 'bad_fixture2.xml',
                 verbosity=0,
             )
-            warning = warning_list.pop()
-            self.assertEqual(warning.category, RuntimeWarning)
-            self.assertEqual(
-                str(warning.message),
-                "No fixture data found for 'bad_fixture2'. (File format may be invalid.)"
-            )
 
     def test_invalid_data_no_ext(self):
         """
@@ -229,55 +222,40 @@ class TestFixtures(TestCase):
         without file extension.
         Test for ticket #18213 -- warning conditions are caught correctly
         """
-        with warnings.catch_warnings(record=True) as warning_list:
-            warnings.simplefilter("always")
+        msg = "No fixture data found for 'bad_fixture2'. (File format may be invalid.)"
+        with self.assertWarnsMessage(RuntimeWarning, msg):
             management.call_command(
                 'loaddata',
                 'bad_fixture2',
                 verbosity=0,
             )
-            warning = warning_list.pop()
-            self.assertEqual(warning.category, RuntimeWarning)
-            self.assertEqual(
-                str(warning.message),
-                "No fixture data found for 'bad_fixture2'. (File format may be invalid.)"
-            )
 
     def test_empty(self):
         """
         Test for ticket #18213 -- Loading a fixture file with no data output a warning.
         Previously empty fixture raises an error exception, see ticket #4371.
         """
-        with warnings.catch_warnings(record=True) as warning_list:
-            warnings.simplefilter("always")
+        msg = "No fixture data found for 'empty'. (File format may be invalid.)"
+        with self.assertWarnsMessage(RuntimeWarning, msg):
             management.call_command(
                 'loaddata',
                 'empty',
                 verbosity=0,
             )
-            warning = warning_list.pop()
-            self.assertEqual(warning.category, RuntimeWarning)
-            self.assertEqual(str(warning.message), "No fixture data found for 'empty'. (File format may be invalid.)")
 
     def test_error_message(self):
         """
         Regression for #9011 - error message is correct.
         Change from error to warning for ticket #18213.
         """
-        with warnings.catch_warnings(record=True) as warning_list:
-            warnings.simplefilter("always")
+        msg = "No fixture data found for 'bad_fixture2'. (File format may be invalid.)"
+        with self.assertWarnsMessage(RuntimeWarning, msg):
             management.call_command(
                 'loaddata',
                 'bad_fixture2',
                 'animal',
                 verbosity=0,
             )
-            warning = warning_list.pop()
-            self.assertEqual(warning.category, RuntimeWarning)
-            self.assertEqual(
-                str(warning.message),
-                "No fixture data found for 'bad_fixture2'. (File format may be invalid.)"
-            )
 
     def test_pg_sequence_resetting_checks(self):
         """

+ 2 - 8
tests/forms_tests/tests/test_media.py

@@ -1,5 +1,3 @@
-import warnings
-
 from django.forms import CharField, Form, Media, MultiWidget, TextInput
 from django.template import Context, Template
 from django.test import SimpleTestCase, override_settings
@@ -540,10 +538,6 @@ class FormsMediaTestCase(SimpleTestCase):
                 self.assertEqual(Media.merge(list1, list2), expected)
 
     def test_merge_warning(self):
-        with warnings.catch_warnings(record=True) as w:
-            warnings.simplefilter('always')
+        msg = 'Detected duplicate Media files in an opposite order:\n1\n2'
+        with self.assertWarnsMessage(RuntimeWarning, msg):
             self.assertEqual(Media.merge([1, 2], [2, 1]), [1, 2])
-            self.assertEqual(
-                str(w[-1].message),
-                'Detected duplicate Media files in an opposite order:\n1\n2'
-            )

+ 6 - 11
tests/from_db_value/test_deprecated.py

@@ -1,6 +1,5 @@
-import warnings
-
 from django.test import TestCase
+from django.utils.deprecation import RemovedInDjango30Warning
 
 from .models import Cash, CashModelDeprecated
 
@@ -8,15 +7,11 @@ from .models import Cash, CashModelDeprecated
 class FromDBValueDeprecationTests(TestCase):
 
     def test_deprecation(self):
-        CashModelDeprecated.objects.create(cash='12.50')
-        with warnings.catch_warnings(record=True) as warns:
-            warnings.simplefilter('always')
-            instance = CashModelDeprecated.objects.get()
-        self.assertIsInstance(instance.cash, Cash)
-        self.assertEqual(len(warns), 1)
-        msg = str(warns[0].message)
-        self.assertEqual(
-            msg,
+        msg = (
             'Remove the context parameter from CashFieldDeprecated.from_db_value(). '
             'Support for it will be removed in Django 3.0.'
         )
+        CashModelDeprecated.objects.create(cash='12.50')
+        with self.assertWarnsMessage(RemovedInDjango30Warning, msg):
+            instance = CashModelDeprecated.objects.get()
+        self.assertIsInstance(instance.cash, Cash)

+ 5 - 9
tests/get_earliest_or_latest/tests.py

@@ -1,7 +1,7 @@
-import warnings
 from datetime import datetime
 
 from django.test import TestCase
+from django.utils.deprecation import RemovedInDjango30Warning
 
 from .models import Article, IndexErrorArticle, Person
 
@@ -169,16 +169,12 @@ class EarliestOrLatestTests(TestCase):
 
     def test_field_name_kwarg_deprecation(self):
         Person.objects.create(name='Deprecator', birthday=datetime(1950, 1, 1))
-        with warnings.catch_warnings(record=True) as warns:
-            warnings.simplefilter('always')
-            Person.objects.latest(field_name='birthday')
-
-        self.assertEqual(len(warns), 1)
-        self.assertEqual(
-            str(warns[0].message),
+        msg = (
             'The field_name keyword argument to earliest() and latest() '
-            'is deprecated in favor of passing positional arguments.',
+            'is deprecated in favor of passing positional arguments.'
         )
+        with self.assertWarnsMessage(RemovedInDjango30Warning, msg):
+            Person.objects.latest(field_name='birthday')
 
 
 class TestFirstLast(TestCase):

+ 2 - 12
tests/migrations/test_graph.py

@@ -1,5 +1,3 @@
-import warnings
-
 from django.db.migrations.exceptions import (
     CircularDependencyError, NodeNotFoundError,
 )
@@ -193,22 +191,14 @@ class GraphTests(SimpleTestCase):
             expected.append(child)
         leaf = expected[-1]
 
-        with warnings.catch_warnings(record=True) as w:
-            warnings.simplefilter('always', RuntimeWarning)
+        with self.assertWarnsMessage(RuntimeWarning, RECURSION_DEPTH_WARNING):
             forwards_plan = graph.forwards_plan(leaf)
 
-        self.assertEqual(len(w), 1)
-        self.assertTrue(issubclass(w[-1].category, RuntimeWarning))
-        self.assertEqual(str(w[-1].message), RECURSION_DEPTH_WARNING)
         self.assertEqual(expected, forwards_plan)
 
-        with warnings.catch_warnings(record=True) as w:
-            warnings.simplefilter('always', RuntimeWarning)
+        with self.assertWarnsMessage(RuntimeWarning, RECURSION_DEPTH_WARNING):
             backwards_plan = graph.backwards_plan(root)
 
-        self.assertEqual(len(w), 1)
-        self.assertTrue(issubclass(w[-1].category, RuntimeWarning))
-        self.assertEqual(str(w[-1].message), RECURSION_DEPTH_WARNING)
         self.assertEqual(expected[::-1], backwards_plan)
 
     def test_plan_invalid_node(self):

+ 4 - 9
tests/modeladmin/test_has_add_permission_obj_deprecation.py

@@ -1,5 +1,3 @@
-import warnings
-
 from django.contrib.admin.options import ModelAdmin, TabularInline
 from django.utils.deprecation import RemovedInDjango30Warning
 
@@ -52,12 +50,9 @@ class HasAddPermissionObjTests(CheckTestCase):
         class BandAdmin(ModelAdmin):
             inlines = [SongInlineAdmin]
 
-        with warnings.catch_warnings(record=True) as recorded:
-            warnings.simplefilter('always')
-            self.assertIsValid(BandAdmin, Band)
-        self.assertEqual(len(recorded), 1)
-        self.assertIs(recorded[0].category, RemovedInDjango30Warning)
-        self.assertEqual(str(recorded[0].message), (
+        msg = (
             "Update SongInlineAdmin.has_add_permission() to accept a "
             "positional `obj` argument."
-        ))
+        )
+        with self.assertWarnsMessage(RemovedInDjango30Warning, msg):
+            self.assertIsValid(BandAdmin, Band)

+ 9 - 17
tests/pagination/tests.py

@@ -1,5 +1,4 @@
 import unittest
-import warnings
 from datetime import datetime
 
 from django.core.paginator import (
@@ -359,20 +358,15 @@ class ModelPaginationTests(TestCase):
         self.assertIsInstance(p.object_list, list)
 
     def test_paginating_unordered_queryset_raises_warning(self):
-        with warnings.catch_warnings(record=True) as warns:
-            # Prevent the RuntimeWarning subclass from appearing as an
-            # exception due to the warnings.simplefilter() in runtests.py.
-            warnings.filterwarnings('always', category=UnorderedObjectListWarning)
-            Paginator(Article.objects.all(), 5)
-        self.assertEqual(len(warns), 1)
-        warning = warns[0]
-        self.assertEqual(str(warning.message), (
+        msg = (
             "Pagination may yield inconsistent results with an unordered "
             "object_list: <class 'pagination.models.Article'> QuerySet."
-        ))
+        )
+        with self.assertWarnsMessage(UnorderedObjectListWarning, msg) as cm:
+            Paginator(Article.objects.all(), 5)
         # The warning points at the Paginator caller (i.e. the stacklevel
         # is appropriate).
-        self.assertEqual(warning.filename, __file__)
+        self.assertEqual(cm.filename, __file__)
 
     def test_paginating_unordered_object_list_raises_warning(self):
         """
@@ -382,11 +376,9 @@ class ModelPaginationTests(TestCase):
         class ObjectList:
             ordered = False
         object_list = ObjectList()
-        with warnings.catch_warnings(record=True) as warns:
-            warnings.filterwarnings('always', category=UnorderedObjectListWarning)
-            Paginator(object_list, 5)
-        self.assertEqual(len(warns), 1)
-        self.assertEqual(str(warns[0].message), (
+        msg = (
             "Pagination may yield inconsistent results with an unordered "
             "object_list: {!r}.".format(object_list)
-        ))
+        )
+        with self.assertWarnsMessage(UnorderedObjectListWarning, msg):
+            Paginator(object_list, 5)

+ 3 - 8
tests/settings_tests/tests.py

@@ -1,7 +1,6 @@
 import os
 import sys
 import unittest
-import warnings
 from types import ModuleType
 from unittest import mock
 
@@ -349,15 +348,11 @@ class TestComplexSettingOverride(SimpleTestCase):
 
     def test_complex_override_warning(self):
         """Regression test for #19031"""
-        with warnings.catch_warnings(record=True) as w:
-            warnings.simplefilter("always")
-
+        msg = 'Overriding setting TEST_WARN can lead to unexpected behavior.'
+        with self.assertWarnsMessage(UserWarning, msg) as cm:
             with override_settings(TEST_WARN='override'):
                 self.assertEqual(settings.TEST_WARN, 'override')
-
-            self.assertEqual(len(w), 1)
-            self.assertEqual(w[0].filename, __file__)
-            self.assertEqual(str(w[0].message), 'Overriding setting TEST_WARN can lead to unexpected behavior.')
+        self.assertEqual(cm.filename, __file__)
 
 
 class SecureProxySslHeaderTest(SimpleTestCase):

+ 2 - 11
tests/staticfiles_tests/test_templatetag_deprecation.py

@@ -1,4 +1,3 @@
-import warnings
 from urllib.parse import urljoin
 
 from django.contrib.staticfiles import storage
@@ -22,24 +21,16 @@ class StaticDeprecationTests(SimpleTestCase):
     def test_templatetag_deprecated(self):
         msg = '{% load staticfiles %} is deprecated in favor of {% load static %}.'
         template = "{% load staticfiles %}{% static 'main.js' %}"
-        with warnings.catch_warnings(record=True) as recorded:
-            warnings.simplefilter('always')
+        with self.assertWarnsMessage(RemovedInDjango30Warning, msg):
             template = Template(template)
         rendered = template.render(Context())
         self.assertEqual(rendered, 'https://example.com/assets/main.js')
-        self.assertEqual(len(recorded), 1)
-        self.assertIs(recorded[0].category, RemovedInDjango30Warning)
-        self.assertEqual(str(recorded[0].message), msg)
 
     def test_static_deprecated(self):
         msg = (
             'django.contrib.staticfiles.templatetags.static() is deprecated in '
             'favor of django.templatetags.static.static().'
         )
-        with warnings.catch_warnings(record=True) as recorded:
-            warnings.simplefilter('always')
+        with self.assertWarnsMessage(RemovedInDjango30Warning, msg):
             url = static('main.js')
         self.assertEqual(url, 'https://example.com/assets/main.js')
-        self.assertEqual(len(recorded), 1)
-        self.assertIs(recorded[0].category, RemovedInDjango30Warning)
-        self.assertEqual(str(recorded[0].message), msg)

+ 25 - 0
tests/test_utils/tests.py

@@ -1,5 +1,6 @@
 import os
 import unittest
+import warnings
 from io import StringIO
 from unittest import mock
 
@@ -864,6 +865,30 @@ class AssertRaisesMsgTest(SimpleTestCase):
             func1()
 
 
+class AssertWarnsMessageTests(SimpleTestCase):
+
+    def test_context_manager(self):
+        with self.assertWarnsMessage(UserWarning, 'Expected message'):
+            warnings.warn('Expected message', UserWarning)
+
+    def test_context_manager_failure(self):
+        msg = "Expected message' not found in 'Unexpected message'"
+        with self.assertRaisesMessage(AssertionError, msg):
+            with self.assertWarnsMessage(UserWarning, 'Expected message'):
+                warnings.warn('Unexpected message', UserWarning)
+
+    def test_callable(self):
+        def func():
+            warnings.warn('Expected message', UserWarning)
+        self.assertWarnsMessage(UserWarning, 'Expected message', func)
+
+    def test_special_re_chars(self):
+        def func1():
+            warnings.warn('[.*x+]y?', UserWarning)
+        with self.assertWarnsMessage(UserWarning, '[.*x+]y?'):
+            func1()
+
+
 class AssertFieldOutputTests(SimpleTestCase):
 
     def test_assert_field_output(self):

+ 8 - 27
tests/timezones/tests.py

@@ -1,7 +1,6 @@
 import datetime
 import re
 import sys
-import warnings
 from contextlib import contextmanager
 from unittest import SkipTest, skipIf
 from xml.dom.minidom import parseString
@@ -226,17 +225,13 @@ class LegacyDatabaseTests(TestCase):
 
 @override_settings(TIME_ZONE='Africa/Nairobi', USE_TZ=True)
 class NewDatabaseTests(TestCase):
+    naive_warning = 'DateTimeField Event.dt received a naive datetime'
 
     @requires_tz_support
     def test_naive_datetime(self):
         dt = datetime.datetime(2011, 9, 1, 13, 20, 30)
-        with warnings.catch_warnings(record=True) as recorded:
-            warnings.simplefilter('always')
+        with self.assertWarnsMessage(RuntimeWarning, self.naive_warning):
             Event.objects.create(dt=dt)
-            self.assertEqual(len(recorded), 1)
-            msg = str(recorded[0].message)
-            self.assertTrue(msg.startswith("DateTimeField Event.dt received "
-                                           "a naive datetime"))
         event = Event.objects.get()
         # naive datetimes are interpreted in local time
         self.assertEqual(event.dt, dt.replace(tzinfo=EAT))
@@ -244,26 +239,16 @@ class NewDatabaseTests(TestCase):
     @requires_tz_support
     def test_datetime_from_date(self):
         dt = datetime.date(2011, 9, 1)
-        with warnings.catch_warnings(record=True) as recorded:
-            warnings.simplefilter('always')
+        with self.assertWarnsMessage(RuntimeWarning, self.naive_warning):
             Event.objects.create(dt=dt)
-            self.assertEqual(len(recorded), 1)
-            msg = str(recorded[0].message)
-            self.assertTrue(msg.startswith("DateTimeField Event.dt received "
-                                           "a naive datetime"))
         event = Event.objects.get()
         self.assertEqual(event.dt, datetime.datetime(2011, 9, 1, tzinfo=EAT))
 
     @requires_tz_support
     def test_naive_datetime_with_microsecond(self):
         dt = datetime.datetime(2011, 9, 1, 13, 20, 30, 405060)
-        with warnings.catch_warnings(record=True) as recorded:
-            warnings.simplefilter('always')
+        with self.assertWarnsMessage(RuntimeWarning, self.naive_warning):
             Event.objects.create(dt=dt)
-            self.assertEqual(len(recorded), 1)
-            msg = str(recorded[0].message)
-            self.assertTrue(msg.startswith("DateTimeField Event.dt received "
-                                           "a naive datetime"))
         event = Event.objects.get()
         # naive datetimes are interpreted in local time
         self.assertEqual(event.dt, dt.replace(tzinfo=EAT))
@@ -330,17 +315,13 @@ class NewDatabaseTests(TestCase):
         dt = datetime.datetime(2011, 9, 1, 12, 20, 30, tzinfo=EAT)
         Event.objects.create(dt=dt)
         dt = dt.replace(tzinfo=None)
-        with warnings.catch_warnings(record=True) as recorded:
-            warnings.simplefilter('always')
-            # naive datetimes are interpreted in local time
+        # naive datetimes are interpreted in local time
+        with self.assertWarnsMessage(RuntimeWarning, self.naive_warning):
             self.assertEqual(Event.objects.filter(dt__exact=dt).count(), 1)
+        with self.assertWarnsMessage(RuntimeWarning, self.naive_warning):
             self.assertEqual(Event.objects.filter(dt__lte=dt).count(), 1)
+        with self.assertWarnsMessage(RuntimeWarning, self.naive_warning):
             self.assertEqual(Event.objects.filter(dt__gt=dt).count(), 0)
-            self.assertEqual(len(recorded), 3)
-            for warning in recorded:
-                msg = str(warning.message)
-                self.assertTrue(msg.startswith("DateTimeField Event.dt "
-                                               "received a naive datetime"))
 
     @skipUnlessDBFeature('has_zoneinfo_database')
     def test_query_datetime_lookups(self):