tests.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. import os
  2. import re
  3. import tempfile
  4. import threading
  5. import unittest
  6. from contextlib import contextmanager
  7. from pathlib import Path
  8. from unittest import mock
  9. from django.core.exceptions import ImproperlyConfigured
  10. from django.db import (
  11. DEFAULT_DB_ALIAS,
  12. NotSupportedError,
  13. connection,
  14. connections,
  15. transaction,
  16. )
  17. from django.db.models import Aggregate, Avg, StdDev, Sum, Variance
  18. from django.db.utils import ConnectionHandler
  19. from django.test import SimpleTestCase, TestCase, TransactionTestCase, override_settings
  20. from django.test.utils import CaptureQueriesContext, isolate_apps
  21. from ..models import Item, Object, Square
  @unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests")
  class Tests(TestCase):
      """General checks of SQLite-specific backend behavior."""

      longMessage = True

      def test_aggregation(self):
          """Raise NotSupportedError when aggregating on date/time fields."""
          temporal_fields = ("time", "date", "last_modified")
          for aggregate in (Sum, Avg, Variance, StdDev):
              # Each date/time column is checked on its own first.
              for field_name in temporal_fields:
                  with self.assertRaises(NotSupportedError):
                      Item.objects.aggregate(aggregate(field_name))
              # Then an expression combining two aggregates of such a column.
              with self.assertRaises(NotSupportedError):
                  Item.objects.aggregate(
                      complex=aggregate("last_modified") + aggregate("last_modified")
                  )

      def test_distinct_aggregation(self):
          # DISTINCT together with several arguments is expected to be refused.
          class DistinctAggregate(Aggregate):
              allow_distinct = True

          expected_message = (
              "SQLite doesn't support DISTINCT on aggregate functions accepting "
              "multiple arguments."
          )
          expression = DistinctAggregate("first", "second", distinct=True)
          with self.assertRaisesMessage(NotSupportedError, expected_message):
              connection.ops.check_expression_support(expression)

      def test_distinct_aggregation_multiple_args_no_distinct(self):
          # Aggregate functions accept multiple arguments when DISTINCT isn't
          # used, e.g. GROUP_CONCAT(), so the support check must not raise.
          class DistinctAggregate(Aggregate):
              allow_distinct = True

          expression = DistinctAggregate("first", "second", distinct=False)
          connection.ops.check_expression_support(expression)

      def test_memory_db_test_name(self):
          """A named in-memory db should be allowed where supported."""
          from django.db.backends.sqlite3.base import DatabaseWrapper

          test_settings = {"NAME": "file:memorydb_test?mode=memory&cache=shared"}
          creation = DatabaseWrapper({"TEST": test_settings}).creation
          actual_name = creation._get_test_db_name()
          configured_name = creation.connection.settings_dict["TEST"]["NAME"]
          self.assertEqual(actual_name, configured_name)

      def test_regexp_function(self):
          # Each case is (subject, pattern, expected); whenever either operand
          # is None the expected result is None.
          cases = (
              ("test", r"[0-9]+", False),
              ("test", r"[a-z]+", True),
              ("test", None, None),
              (None, r"[a-z]+", None),
              (None, None, None),
          )
          for subject, regex, expected in cases:
              with self.subTest((subject, regex)):
                  with connection.cursor() as cursor:
                      cursor.execute("SELECT %s REGEXP %s", [subject, regex])
                      row = cursor.fetchone()
                  result = row[0]
                  # 0/1 are turned into booleans so assertIs() can compare them.
                  if result in {0, 1}:
                      result = bool(result)
                  self.assertIs(result, expected)

      def test_pathlib_name(self):
          # The database NAME is supplied as a pathlib.Path, not a string.
          with tempfile.TemporaryDirectory() as tmp:
              database_path = Path(tmp) / "test.db"
              handler = ConnectionHandler(
                  {
                      "default": {
                          "ENGINE": "django.db.backends.sqlite3",
                          "NAME": database_path,
                      },
                  }
              )
              handler["default"].ensure_connection()
              handler["default"].close()
              self.assertTrue(os.path.isfile(os.path.join(tmp, "test.db")))

      @mock.patch.object(connection, "get_database_version", return_value=(3, 26))
      def test_check_database_version_supported(self, mocked_get_database_version):
          # The patched version lookup reports (3, 26).
          expected_message = "SQLite 3.27 or later is required (found 3.26)."
          with self.assertRaisesMessage(NotSupportedError, expected_message):
              connection.check_database_version_supported()
          self.assertTrue(mocked_get_database_version.called)
  @unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests")
  @isolate_apps("backends")
  class SchemaTests(TransactionTestCase):
      """Schema editor checks for the SQLite backend."""

      available_apps = ["backends"]

      def test_autoincrement(self):
          """
          auto_increment fields are created with the AUTOINCREMENT keyword
          in order to be monotonically increasing (#10164).
          """
          # collect_sql=True makes the editor expose its SQL as collected_sql.
          with connection.schema_editor(collect_sql=True) as editor:
              editor.create_model(Square)
              collected = editor.collected_sql
          # Capture the definition that follows the "id" column name.
          id_column = re.search('"id" ([^,]+),', collected[0])
          self.assertIsNotNone(id_column)
          self.assertEqual(
              "integer NOT NULL PRIMARY KEY AUTOINCREMENT",
              id_column.group(1),
              "Wrong SQL used to create an auto-increment column on SQLite",
          )

      def test_disable_constraint_checking_failure_disallowed(self):
          """
          SQLite schema editor is not usable within an outer transaction if
          foreign key constraint checks are not disabled beforehand.
          """
          expected_message = (
              "SQLite schema editor cannot be used while foreign key "
              "constraint checks are enabled. Make sure to disable them "
              "before entering a transaction.atomic() context because "
              "SQLite does not support disabling them in the middle of "
              "a multi-statement transaction."
          )
          with self.assertRaisesMessage(NotSupportedError, expected_message):
              with transaction.atomic():
                  with connection.schema_editor(atomic=True):
                      pass

      def test_constraint_checks_disabled_atomic_allowed(self):
          """
          SQLite schema editor is usable within an outer transaction as long as
          foreign key constraints checks are disabled beforehand.
          """

          def foreign_keys_on():
              # Read the current value of the foreign_keys pragma.
              with connection.cursor() as cursor:
                  row = cursor.execute("PRAGMA foreign_keys").fetchone()
                  return bool(row[0])

          with connection.constraint_checks_disabled():
              with transaction.atomic():
                  with connection.schema_editor(atomic=True):
                      self.assertFalse(foreign_keys_on())
                  # Still off after the schema editor exits.
                  self.assertFalse(foreign_keys_on())
          # Back on once constraint_checks_disabled() has exited.
          self.assertTrue(foreign_keys_on())
  @unittest.skipUnless(connection.vendor == "sqlite", "Test only for SQLite")
  @override_settings(DEBUG=True)
  class LastExecutedQueryTest(TestCase):
      """Checks of the SQL text recorded for executed queries (DEBUG=True)."""

      def test_no_interpolation(self):
          # A query holding a literal "%" and no parameters shouldn't raise an
          # exception (#17158).
          sql = "SELECT strftime('%Y', 'now');"
          with connection.cursor() as cursor:
              cursor.execute(sql)
          last_logged = connection.queries[-1]
          self.assertEqual(last_logged["sql"], sql)

      def test_parameter_quoting(self):
          # The implementation of last_executed_queries isn't optimal. It's
          # worth testing that parameters are quoted (#14091).
          with connection.cursor() as cursor:
              cursor.execute("SELECT %s", ["\"'\\"])
          # Note that the single quote is repeated
          expected_sql = "SELECT '\"''\\'"
          self.assertEqual(connection.queries[-1]["sql"], expected_sql)

      def test_large_number_of_parameters(self):
          # If SQLITE_MAX_VARIABLE_NUMBER (default = 999) has been changed to be
          # greater than SQLITE_MAX_COLUMN (default = 2000), last_executed_query
          # can hit the SQLITE_MAX_COLUMN limit (#26063).
          parameter_count = 2001
          with connection.cursor() as cursor:
              placeholders = ", ".join(["%s"] * parameter_count)
              sql = "SELECT MAX(%s)" % placeholders
              params = list(range(parameter_count))
              # This should not raise an exception.
              cursor.db.ops.last_executed_query(cursor.cursor, sql, params)
  @unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests")
  class EscapingChecks(TestCase):
      """
      All tests in this test case are also run with settings.DEBUG=True in
      EscapingChecksDebug test case, to also test CursorDebugWrapper.
      """

      def test_parameter_escaping(self):
          # '%s' escaping support for sqlite3 (#13648): the query carries a
          # literal '%s' and is executed without parameters.
          with connection.cursor() as cursor:
              cursor.execute("select strftime('%s', date('now'))")
              rows = cursor.fetchall()
          # The first column of the first row must convert to a non-zero integer.
          first_value = rows[0][0]
          self.assertTrue(int(first_value))
  @unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests")
  @override_settings(DEBUG=True)
  class EscapingChecksDebug(EscapingChecks):
      """Run the inherited EscapingChecks tests with settings.DEBUG=True."""
  @unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests")
  class ThreadSharing(TransactionTestCase):
      """Check that rows created from another thread are visible afterwards."""

      available_apps = ["backends"]

      def test_database_sharing_in_threads(self):
          # One Object is created from the calling thread and one from a worker
          # thread; the final count must include both.
          seen_connections = []

          def create_object():
              Object.objects.create()
              # Record the underlying connection used by the current thread.
              seen_connections.append(connections[DEFAULT_DB_ALIAS].connection)

          main_connection = connections[DEFAULT_DB_ALIAS].connection
          try:
              create_object()
              worker = threading.Thread(target=create_object)
              worker.start()
              worker.join()
              self.assertEqual(Object.objects.count(), 2)
          finally:
              # Close every recorded connection except the main one.
              for candidate in seen_connections:
                  if candidate is main_connection:
                      continue
                  candidate.close()
  @unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests")
  class TestTransactionMode(SimpleTestCase):
      """Checks of the "transaction_mode" entry in the database OPTIONS."""

      databases = {"default"}

      def test_default_transaction_mode(self):
          # On the unmodified connection, atomic() is expected to produce
          # exactly two queries: BEGIN and COMMIT.
          with CaptureQueriesContext(connection) as captured_queries, transaction.atomic():
              pass
          first_query, second_query = captured_queries
          self.assertEqual(first_query["sql"], "BEGIN")
          self.assertEqual(second_query["sql"], "COMMIT")

      def test_invalid_transaction_mode(self):
          # An unknown mode must be rejected when the connection is opened.
          expected_message = (
              "settings.DATABASES['default']['OPTIONS']['transaction_mode'] is "
              "improperly configured to 'invalid'. Use one of 'DEFERRED', 'EXCLUSIVE', "
              "'IMMEDIATE', or None."
          )
          with (
              self.change_transaction_mode("invalid") as new_connection,
              self.assertRaisesMessage(ImproperlyConfigured, expected_message),
          ):
              new_connection.ensure_connection()

      def test_valid_transaction_modes(self):
          # Every accepted mode should appear, upper-cased, in the BEGIN statement.
          for mode in ("deferred", "immediate", "exclusive"):
              with self.subTest(transaction_mode=mode):
                  with self.change_transaction_mode(mode) as new_connection:
                      with CaptureQueriesContext(new_connection) as captured_queries:
                          new_connection.set_autocommit(
                              False,
                              force_begin_transaction_with_broken_autocommit=True,
                          )
                          new_connection.commit()
                          first_sql = captured_queries[0]["sql"]
                          self.assertEqual(first_sql, f"BEGIN {mode.upper()}")

      @contextmanager
      def change_transaction_mode(self, transaction_mode):
          """
          Yield a copy of the connection whose OPTIONS carry the given
          transaction_mode, and close that copy on exit.
          """
          new_connection = connection.copy()
          # Build a fresh OPTIONS dict instead of mutating the existing one.
          options = dict(new_connection.settings_dict["OPTIONS"])
          options["transaction_mode"] = transaction_mode
          new_connection.settings_dict["OPTIONS"] = options
          try:
              yield new_connection
          finally:
              new_connection.close()
  259. new_connection.close()