|
@@ -4,9 +4,11 @@ import posixpath
|
|
|
import sys
|
|
|
import threading
|
|
|
import unittest
|
|
|
+import warnings
|
|
|
from collections import Counter
|
|
|
from contextlib import contextmanager
|
|
|
from copy import copy
|
|
|
+from difflib import get_close_matches
|
|
|
from functools import wraps
|
|
|
from unittest.util import safe_repr
|
|
|
from urllib.parse import (
|
|
@@ -17,7 +19,7 @@ from urllib.request import url2pathname
|
|
|
from django.apps import apps
|
|
|
from django.conf import settings
|
|
|
from django.core import mail
|
|
|
-from django.core.exceptions import ValidationError
|
|
|
+from django.core.exceptions import ImproperlyConfigured, ValidationError
|
|
|
from django.core.files import locks
|
|
|
from django.core.handlers.wsgi import WSGIHandler, get_path_info
|
|
|
from django.core.management import call_command
|
|
@@ -36,6 +38,7 @@ from django.test.utils import (
|
|
|
override_settings,
|
|
|
)
|
|
|
from django.utils.decorators import classproperty
|
|
|
+from django.utils.deprecation import RemovedInDjango31Warning
|
|
|
from django.views.static import serve
|
|
|
|
|
|
__all__ = ('TestCase', 'TransactionTestCase',
|
|
@@ -133,16 +136,31 @@ class _AssertTemplateNotUsedContext(_AssertTemplateUsedContext):
|
|
|
|
|
|
|
|
|
class _CursorFailure:
|
|
|
- def __init__(self, cls_name, wrapped):
|
|
|
- self.cls_name = cls_name
|
|
|
+ def __init__(self, wrapped, message):
|
|
|
self.wrapped = wrapped
|
|
|
+ self.message = message
|
|
|
|
|
|
def __call__(self):
|
|
|
- raise AssertionError(
|
|
|
- "Database queries aren't allowed in SimpleTestCase. "
|
|
|
- "Either use TestCase or TransactionTestCase to ensure proper test isolation or "
|
|
|
- "set %s.allow_database_queries to True to silence this failure." % self.cls_name
|
|
|
- )
|
|
|
+ raise AssertionError(self.message)
|
|
|
+
|
|
|
+
|
|
|
+class _SimpleTestCaseDatabasesDescriptor:
|
|
|
+ """Descriptor for SimpleTestCase.allow_database_queries deprecation."""
|
|
|
+ def __get__(self, instance, cls=None):
|
|
|
+ try:
|
|
|
+ allow_database_queries = cls.allow_database_queries
|
|
|
+ except AttributeError:
|
|
|
+ pass
|
|
|
+ else:
|
|
|
+ msg = (
|
|
|
+ '`SimpleTestCase.allow_database_queries` is deprecated. '
|
|
|
+ 'Restrict the databases available during the execution of '
|
|
|
+ '%s.%s with the `databases` attribute instead.'
|
|
|
+ ) % (cls.__module__, cls.__qualname__)
|
|
|
+ warnings.warn(msg, RemovedInDjango31Warning)
|
|
|
+ if allow_database_queries:
|
|
|
+ return {DEFAULT_DB_ALIAS}
|
|
|
+ return set()
|
|
|
|
|
|
|
|
|
class SimpleTestCase(unittest.TestCase):
|
|
@@ -153,9 +171,13 @@ class SimpleTestCase(unittest.TestCase):
|
|
|
_overridden_settings = None
|
|
|
_modified_settings = None
|
|
|
|
|
|
- # Tests shouldn't be allowed to query the database since
|
|
|
- # this base class doesn't enforce any isolation.
|
|
|
- allow_database_queries = False
|
|
|
+ databases = _SimpleTestCaseDatabasesDescriptor()
|
|
|
+ _disallowed_database_msg = (
|
|
|
+ 'Database queries are not allowed in SimpleTestCase subclasses. '
|
|
|
+ 'Either subclass TestCase or TransactionTestCase to ensure proper '
|
|
|
+ 'test isolation or add %(alias)r to %(test)s.databases to silence '
|
|
|
+ 'this failure.'
|
|
|
+ )
|
|
|
|
|
|
@classmethod
|
|
|
def setUpClass(cls):
|
|
@@ -166,19 +188,51 @@ class SimpleTestCase(unittest.TestCase):
|
|
|
if cls._modified_settings:
|
|
|
cls._cls_modified_context = modify_settings(cls._modified_settings)
|
|
|
cls._cls_modified_context.enable()
|
|
|
- if not cls.allow_database_queries:
|
|
|
- for alias in connections:
|
|
|
- connection = connections[alias]
|
|
|
- connection.cursor = _CursorFailure(cls.__name__, connection.cursor)
|
|
|
- connection.chunked_cursor = _CursorFailure(cls.__name__, connection.chunked_cursor)
|
|
|
+ cls._add_cursor_failures()
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _validate_databases(cls):
|
|
|
+ if cls.databases == '__all__':
|
|
|
+ return frozenset(connections)
|
|
|
+ for alias in cls.databases:
|
|
|
+ if alias not in connections:
|
|
|
+ message = '%s.%s.databases refers to %r which is not defined in settings.DATABASES.' % (
|
|
|
+ cls.__module__,
|
|
|
+ cls.__qualname__,
|
|
|
+ alias,
|
|
|
+ )
|
|
|
+ close_matches = get_close_matches(alias, list(connections))
|
|
|
+ if close_matches:
|
|
|
+ message += ' Did you mean %r?' % close_matches[0]
|
|
|
+ raise ImproperlyConfigured(message)
|
|
|
+ return frozenset(cls.databases)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _add_cursor_failures(cls):
|
|
|
+ cls.databases = cls._validate_databases()
|
|
|
+ for alias in connections:
|
|
|
+ if alias in cls.databases:
|
|
|
+ continue
|
|
|
+ connection = connections[alias]
|
|
|
+ message = cls._disallowed_database_msg % {
|
|
|
+ 'test': '%s.%s' % (cls.__module__, cls.__qualname__),
|
|
|
+ 'alias': alias,
|
|
|
+ }
|
|
|
+ connection.cursor = _CursorFailure(connection.cursor, message)
|
|
|
+ connection.chunked_cursor = _CursorFailure(connection.chunked_cursor, message)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _remove_cursor_failures(cls):
|
|
|
+ for alias in connections:
|
|
|
+ if alias in cls.databases:
|
|
|
+ continue
|
|
|
+ connection = connections[alias]
|
|
|
+ connection.cursor = connection.cursor.wrapped
|
|
|
+ connection.chunked_cursor = connection.chunked_cursor.wrapped
|
|
|
|
|
|
@classmethod
|
|
|
def tearDownClass(cls):
|
|
|
- if not cls.allow_database_queries:
|
|
|
- for alias in connections:
|
|
|
- connection = connections[alias]
|
|
|
- connection.cursor = connection.cursor.wrapped
|
|
|
- connection.chunked_cursor = connection.chunked_cursor.wrapped
|
|
|
+ cls._remove_cursor_failures()
|
|
|
if hasattr(cls, '_cls_modified_context'):
|
|
|
cls._cls_modified_context.disable()
|
|
|
delattr(cls, '_cls_modified_context')
|
|
@@ -806,6 +860,26 @@ class SimpleTestCase(unittest.TestCase):
|
|
|
self.fail(self._formatMessage(msg, standardMsg))
|
|
|
|
|
|
|
|
|
+class _TransactionTestCaseDatabasesDescriptor:
|
|
|
+ """Descriptor for TransactionTestCase.multi_db deprecation."""
|
|
|
+ msg = (
|
|
|
+ '`TransactionTestCase.multi_db` is deprecated. Databases available '
|
|
|
+ 'during this test can be defined using %s.%s.databases.'
|
|
|
+ )
|
|
|
+
|
|
|
+ def __get__(self, instance, cls=None):
|
|
|
+ try:
|
|
|
+ multi_db = cls.multi_db
|
|
|
+ except AttributeError:
|
|
|
+ pass
|
|
|
+ else:
|
|
|
+ msg = self.msg % (cls.__module__, cls.__qualname__)
|
|
|
+ warnings.warn(msg, RemovedInDjango31Warning)
|
|
|
+ if multi_db:
|
|
|
+ return set(connections)
|
|
|
+ return {DEFAULT_DB_ALIAS}
|
|
|
+
|
|
|
+
|
|
|
class TransactionTestCase(SimpleTestCase):
|
|
|
|
|
|
# Subclasses can ask for resetting of auto increment sequence before each
|
|
@@ -818,8 +892,12 @@ class TransactionTestCase(SimpleTestCase):
|
|
|
# Subclasses can define fixtures which will be automatically installed.
|
|
|
fixtures = None
|
|
|
|
|
|
- # Do the tests in this class query non-default databases?
|
|
|
- multi_db = False
|
|
|
+ databases = _TransactionTestCaseDatabasesDescriptor()
|
|
|
+ _disallowed_database_msg = (
|
|
|
+ 'Database queries to %(alias)r are not allowed in this test. Add '
|
|
|
+ '%(alias)r to %(test)s.databases to ensure proper test isolation '
|
|
|
+ 'and silence this failure.'
|
|
|
+ )
|
|
|
|
|
|
# If transactions aren't available, Django will serialize the database
|
|
|
# contents into a fixture during setup and flush and reload them
|
|
@@ -827,10 +905,6 @@ class TransactionTestCase(SimpleTestCase):
|
|
|
# This can be slow; this flag allows enabling on a per-case basis.
|
|
|
serialized_rollback = False
|
|
|
|
|
|
- # Since tests will be wrapped in a transaction, or serialized if they
|
|
|
- # are not available, we allow queries to be run.
|
|
|
- allow_database_queries = True
|
|
|
-
|
|
|
def _pre_setup(self):
|
|
|
"""
|
|
|
Perform pre-test setup:
|
|
@@ -870,15 +944,13 @@ class TransactionTestCase(SimpleTestCase):
|
|
|
|
|
|
@classmethod
|
|
|
def _databases_names(cls, include_mirrors=True):
|
|
|
- # If the test case has a multi_db=True flag, act on all databases,
|
|
|
- # including mirrors or not. Otherwise, just on the default DB.
|
|
|
- if cls.multi_db:
|
|
|
- return [
|
|
|
- alias for alias in connections
|
|
|
- if include_mirrors or not connections[alias].settings_dict['TEST']['MIRROR']
|
|
|
- ]
|
|
|
- else:
|
|
|
- return [DEFAULT_DB_ALIAS]
|
|
|
+ # Only consider allowed database aliases, including mirrors or not.
|
|
|
+ return [
|
|
|
+ alias for alias in connections
|
|
|
+ if alias in cls.databases and (
|
|
|
+ include_mirrors or not connections[alias].settings_dict['TEST']['MIRROR']
|
|
|
+ )
|
|
|
+ ]
|
|
|
|
|
|
def _reset_sequences(self, db_name):
|
|
|
conn = connections[db_name]
|
|
@@ -984,9 +1056,21 @@ class TransactionTestCase(SimpleTestCase):
|
|
|
func(*args, **kwargs)
|
|
|
|
|
|
|
|
|
-def connections_support_transactions():
|
|
|
- """Return True if all connections support transactions."""
|
|
|
- return all(conn.features.supports_transactions for conn in connections.all())
|
|
|
+def connections_support_transactions(aliases=None):
|
|
|
+ """
|
|
|
+ Return whether or not all (or specified) connections support
|
|
|
+ transactions.
|
|
|
+ """
|
|
|
+ conns = connections.all() if aliases is None else (connections[alias] for alias in aliases)
|
|
|
+ return all(conn.features.supports_transactions for conn in conns)
|
|
|
+
|
|
|
+
|
|
|
+class _TestCaseDatabasesDescriptor(_TransactionTestCaseDatabasesDescriptor):
|
|
|
+ """Descriptor for TestCase.multi_db deprecation."""
|
|
|
+ msg = (
|
|
|
+ '`TestCase.multi_db` is deprecated. Databases available during this '
|
|
|
+ 'test can be defined using %s.%s.databases.'
|
|
|
+ )
|
|
|
|
|
|
|
|
|
class TestCase(TransactionTestCase):
|
|
@@ -1002,6 +1086,8 @@ class TestCase(TransactionTestCase):
|
|
|
On database backends with no transaction support, TestCase behaves as
|
|
|
TransactionTestCase.
|
|
|
"""
|
|
|
+ databases = _TestCaseDatabasesDescriptor()
|
|
|
+
|
|
|
@classmethod
|
|
|
def _enter_atomics(cls):
|
|
|
"""Open atomic blocks for multiple databases."""
|
|
@@ -1018,10 +1104,14 @@ class TestCase(TransactionTestCase):
|
|
|
transaction.set_rollback(True, using=db_name)
|
|
|
atomics[db_name].__exit__(None, None, None)
|
|
|
|
|
|
+ @classmethod
|
|
|
+ def _databases_support_transactions(cls):
|
|
|
+ return connections_support_transactions(cls.databases)
|
|
|
+
|
|
|
@classmethod
|
|
|
def setUpClass(cls):
|
|
|
super().setUpClass()
|
|
|
- if not connections_support_transactions():
|
|
|
+ if not cls._databases_support_transactions():
|
|
|
return
|
|
|
cls.cls_atomics = cls._enter_atomics()
|
|
|
|
|
@@ -1031,16 +1121,18 @@ class TestCase(TransactionTestCase):
|
|
|
call_command('loaddata', *cls.fixtures, **{'verbosity': 0, 'database': db_name})
|
|
|
except Exception:
|
|
|
cls._rollback_atomics(cls.cls_atomics)
|
|
|
+ cls._remove_cursor_failures()
|
|
|
raise
|
|
|
try:
|
|
|
cls.setUpTestData()
|
|
|
except Exception:
|
|
|
cls._rollback_atomics(cls.cls_atomics)
|
|
|
+ cls._remove_cursor_failures()
|
|
|
raise
|
|
|
|
|
|
@classmethod
|
|
|
def tearDownClass(cls):
|
|
|
- if connections_support_transactions():
|
|
|
+ if cls._databases_support_transactions():
|
|
|
cls._rollback_atomics(cls.cls_atomics)
|
|
|
for conn in connections.all():
|
|
|
conn.close()
|
|
@@ -1052,12 +1144,12 @@ class TestCase(TransactionTestCase):
|
|
|
pass
|
|
|
|
|
|
def _should_reload_connections(self):
|
|
|
- if connections_support_transactions():
|
|
|
+ if self._databases_support_transactions():
|
|
|
return False
|
|
|
return super()._should_reload_connections()
|
|
|
|
|
|
def _fixture_setup(self):
|
|
|
- if not connections_support_transactions():
|
|
|
+ if not self._databases_support_transactions():
|
|
|
# If the backend does not support transactions, we should reload
|
|
|
# class data before each test
|
|
|
self.setUpTestData()
|
|
@@ -1067,7 +1159,7 @@ class TestCase(TransactionTestCase):
|
|
|
self.atomics = self._enter_atomics()
|
|
|
|
|
|
def _fixture_teardown(self):
|
|
|
- if not connections_support_transactions():
|
|
|
+ if not self._databases_support_transactions():
|
|
|
return super()._fixture_teardown()
|
|
|
try:
|
|
|
for db_name in reversed(self._databases_names()):
|