2
0

tests.py 79 KB


  1. import os
  2. import unittest
  3. import warnings
  4. from io import StringIO
  5. from unittest import mock
  6. from django.conf import STATICFILES_STORAGE_ALIAS, settings
  7. from django.contrib.staticfiles.finders import get_finder, get_finders
  8. from django.contrib.staticfiles.storage import staticfiles_storage
  9. from django.core.exceptions import ImproperlyConfigured
  10. from django.core.files.storage import default_storage
  11. from django.db import (
  12. IntegrityError,
  13. connection,
  14. connections,
  15. models,
  16. router,
  17. transaction,
  18. )
  19. from django.forms import (
  20. CharField,
  21. EmailField,
  22. Form,
  23. IntegerField,
  24. ValidationError,
  25. formset_factory,
  26. )
  27. from django.http import HttpResponse
  28. from django.template.loader import render_to_string
  29. from django.test import (
  30. SimpleTestCase,
  31. TestCase,
  32. TransactionTestCase,
  33. skipIfDBFeature,
  34. skipUnlessDBFeature,
  35. )
  36. from django.test.html import HTMLParseError, parse_html
  37. from django.test.testcases import DatabaseOperationForbidden
  38. from django.test.utils import (
  39. CaptureQueriesContext,
  40. TestContextDecorator,
  41. isolate_apps,
  42. override_settings,
  43. setup_test_environment,
  44. )
  45. from django.urls import NoReverseMatch, path, reverse, reverse_lazy
  46. from django.utils.html import VOID_ELEMENTS
  47. from django.utils.version import PY311
  48. from .models import Car, Person, PossessedCar
  49. from .views import empty_response
  class SkippingTestCase(SimpleTestCase):
      """
      Exercise skipUnlessDBFeature and skipIfDBFeature on plain functions and
      on a method of a SimpleTestCase subclass.
      """

      def _assert_skipping(self, func, expected_exc, msg=None):
          """
          Call ``func`` and require it to raise ``expected_exc`` (matching
          ``msg`` when one is given). A SkipTest escaping that check is turned
          into a failure of the running test.
          """
          if msg is None:
              expectation = self.assertRaises(expected_exc)
          else:
              expectation = self.assertRaisesMessage(expected_exc, msg)
          try:
              with expectation:
                  func()
          except unittest.SkipTest:
              self.fail("%s should not result in a skipped test." % func.__name__)

      def test_skip_unless_db_feature(self):
          """
          Check the django.test.skipUnlessDBFeature decorator.
          """

          # Total hack, but it works, just want an attribute that's always true.
          @skipUnlessDBFeature("__class__")
          def test_func():
              raise ValueError

          self._assert_skipping(test_func, ValueError)

          @skipUnlessDBFeature("notprovided")
          def test_func2():
              raise ValueError

          self._assert_skipping(test_func2, unittest.SkipTest)

          @skipUnlessDBFeature("__class__", "__class__")
          def test_func3():
              raise ValueError

          self._assert_skipping(test_func3, ValueError)

          @skipUnlessDBFeature("__class__", "notprovided")
          def test_func4():
              raise ValueError

          self._assert_skipping(test_func4, unittest.SkipTest)

          # The nested class name is part of the expected message below.
          class SkipTestCase(SimpleTestCase):
              @skipUnlessDBFeature("missing")
              def test_foo(self):
                  pass

          expected_msg = (
              "skipUnlessDBFeature cannot be used on test_foo (test_utils.tests."
              "SkippingTestCase.test_skip_unless_db_feature.<locals>.SkipTestCase%s) "
              "as SkippingTestCase.test_skip_unless_db_feature.<locals>.SkipTestCase "
              "doesn't allow queries against the 'default' database."
              # Python 3.11 uses fully qualified test name in the output.
              % (".test_foo" if PY311 else "")
          )
          self._assert_skipping(
              SkipTestCase("test_foo").test_foo, ValueError, expected_msg
          )

      def test_skip_if_db_feature(self):
          """
          Check the django.test.skipIfDBFeature decorator.
          """

          @skipIfDBFeature("__class__")
          def test_func():
              raise ValueError

          self._assert_skipping(test_func, unittest.SkipTest)

          @skipIfDBFeature("notprovided")
          def test_func2():
              raise ValueError

          self._assert_skipping(test_func2, ValueError)

          @skipIfDBFeature("__class__", "__class__")
          def test_func3():
              raise ValueError

          self._assert_skipping(test_func3, unittest.SkipTest)

          @skipIfDBFeature("__class__", "notprovided")
          def test_func4():
              raise ValueError

          self._assert_skipping(test_func4, unittest.SkipTest)

          @skipIfDBFeature("notprovided", "notprovided")
          def test_func5():
              raise ValueError

          self._assert_skipping(test_func5, ValueError)

          # The nested class name is part of the expected message below.
          class SkipTestCase(SimpleTestCase):
              @skipIfDBFeature("missing")
              def test_foo(self):
                  pass

          expected_msg = (
              "skipIfDBFeature cannot be used on test_foo (test_utils.tests."
              "SkippingTestCase.test_skip_if_db_feature.<locals>.SkipTestCase%s) "
              "as SkippingTestCase.test_skip_if_db_feature.<locals>.SkipTestCase "
              "doesn't allow queries against the 'default' database."
              # Python 3.11 uses fully qualified test name in the output.
              % (".test_foo" if PY311 else "")
          )
          self._assert_skipping(
              SkipTestCase("test_foo").test_foo, ValueError, expected_msg
          )
  class SkippingClassTestCase(TestCase):
      """Exercise the DB-feature skip decorators when applied to whole classes."""

      def test_skip_class_unless_db_feature(self):
          @skipUnlessDBFeature("__class__")
          class NotSkippedTests(TestCase):
              def test_dummy(self):
                  return

          @skipUnlessDBFeature("missing")
          @skipIfDBFeature("__class__")
          class SkippedTests(TestCase):
              def test_will_be_skipped(self):
                  self.fail("We should never arrive here.")

          @skipIfDBFeature("__dict__")
          class SkippedTestsSubclass(SkippedTests):
              pass

          suite = unittest.TestSuite()
          suite.addTest(NotSkippedTests("test_dummy"))
          # Building the test instances must not raise SkipTest; the skip is
          # expected to be reported only once the suite runs.
          try:
              for skipped_class in (SkippedTests, SkippedTestsSubclass):
                  suite.addTest(skipped_class("test_will_be_skipped"))
          except unittest.SkipTest:
              self.fail("SkipTest should not be raised here.")
          outcome = unittest.TextTestRunner(stream=StringIO()).run(suite)
          self.assertEqual(outcome.testsRun, 3)
          self.assertEqual(len(outcome.skipped), 2)
          first_reason = outcome.skipped[0][1]
          second_reason = outcome.skipped[1][1]
          self.assertEqual(first_reason, "Database has feature(s) __class__")
          self.assertEqual(second_reason, "Database has feature(s) __class__")

      def test_missing_default_databases(self):
          @skipIfDBFeature("missing")
          class MissingDatabases(SimpleTestCase):
              def test_assertion_error(self):
                  pass

          suite = unittest.TestSuite()
          try:
              suite.addTest(MissingDatabases("test_assertion_error"))
          except unittest.SkipTest:
              self.fail("SkipTest should not be raised at this stage")
          expected_msg = (
              "skipIfDBFeature cannot be used on <class 'test_utils.tests."
              "SkippingClassTestCase.test_missing_default_databases.<locals>."
              "MissingDatabases'> as it doesn't allow queries against the "
              "'default' database."
          )
          runner = unittest.TextTestRunner(stream=StringIO())
          with self.assertRaisesMessage(ValueError, expected_msg):
              runner.run(suite)
  @override_settings(ROOT_URLCONF="test_utils.urls")
  class AssertNumQueriesTests(TestCase):
      """Exercise assertNumQueries() in its callable (non-context-manager) form."""

      def test_assert_num_queries(self):
          # An exception raised by the callable must reach the caller.
          def raise_value_error():
              raise ValueError

          with self.assertRaises(ValueError):
              self.assertNumQueries(2, raise_value_error)

      def test_assert_num_queries_with_client(self):
          person = Person.objects.create(name="test")
          url = "/test_utils/get_person/%s/" % person.pk

          self.assertNumQueries(1, self.client.get, url)
          self.assertNumQueries(1, self.client.get, url)

          def fetch_twice():
              self.client.get(url)
              self.client.get(url)

          self.assertNumQueries(2, fetch_twice)
  class AssertNumQueriesUponConnectionTests(TransactionTestCase):
      """
      Check assertNumQueries() against a connection that is opened, and runs a
      query of its own, while the assertion is active.
      """

      available_apps = []

      def test_ignores_connection_configuration_queries(self):
          original_ensure_connection = connection.ensure_connection
          connection.close()

          def ensure_and_configure():
              # Record the state first: the real call below opens the connection.
              was_closed = connection.connection is None
              original_ensure_connection()
              if not was_closed:
                  return
              # Avoid infinite recursion. Creating a cursor calls
              # ensure_connection() which is currently mocked by this method.
              with connection.cursor() as cursor:
                  cursor.execute("SELECT 1" + connection.features.bare_select_suffix)

          patch_target = (
              "django.db.backends.base.base.BaseDatabaseWrapper.ensure_connection"
          )
          with mock.patch(patch_target, side_effect=ensure_and_configure):
              # Only the queryset evaluation is expected to be counted.
              with self.assertNumQueries(1):
                  list(Car.objects.all())
  class AssertQuerySetEqualTests(TestCase):
      """Exercise assertQuerySetEqual() with ordering, transforms and duplicates."""

      @classmethod
      def setUpTestData(cls):
          cls.p1 = Person.objects.create(name="p1")
          cls.p2 = Person.objects.create(name="p2")

      def test_empty(self):
          no_match = Person.objects.filter(name="p3")
          self.assertQuerySetEqual(no_match, [])

      def test_ordered(self):
          by_name = Person.objects.order_by("name")
          self.assertQuerySetEqual(by_name, [self.p1, self.p2])

      def test_unordered(self):
          by_name = Person.objects.order_by("name")
          self.assertQuerySetEqual(by_name, [self.p2, self.p1], ordered=False)

      def test_queryset(self):
          # Both sides may be querysets.
          actual = Person.objects.order_by("name")
          expected = Person.objects.order_by("name")
          self.assertQuerySetEqual(actual, expected)

      def test_flat_values_list(self):
          person_names = Person.objects.order_by("name").values_list("name", flat=True)
          self.assertQuerySetEqual(person_names, ["p1", "p2"])

      def test_transform(self):
          expected_pks = [self.p1.pk, self.p2.pk]
          self.assertQuerySetEqual(
              Person.objects.order_by("name"),
              expected_pks,
              transform=lambda x: x.pk,
          )

      def test_repr_transform(self):
          expected_reprs = [repr(self.p1), repr(self.p2)]
          self.assertQuerySetEqual(
              Person.objects.order_by("name"),
              expected_reprs,
              transform=repr,
          )

      def test_undefined_order(self):
          # Using an unordered queryset with more than one ordered value
          # is an error.
          expected_msg = (
              "Trying to compare non-ordered queryset against more than one "
              "ordered value."
          )
          with self.assertRaisesMessage(ValueError, expected_msg):
              self.assertQuerySetEqual(Person.objects.all(), [self.p1, self.p2])
          # No error for one value.
          self.assertQuerySetEqual(Person.objects.filter(name="p1"), [self.p1])

      def test_repeated_values(self):
          """
          With ordered=False, assertQuerySetEqual still compares how many
          times each item appears.
          """
          batmobile = Car.objects.create(name="Batmobile")
          k2000 = Car.objects.create(name="K 2000")
          # Two Batmobile rows followed by four K 2000 rows, one instance each.
          owned_cars = [batmobile] * 2 + [k2000] * 4
          PossessedCar.objects.bulk_create(
              [PossessedCar(car=car, belongs_to=self.p1) for car in owned_cars]
          )
          with self.assertRaises(AssertionError):
              self.assertQuerySetEqual(
                  self.p1.cars.all(), [batmobile, k2000], ordered=False
              )
          self.assertQuerySetEqual(
              self.p1.cars.all(), [batmobile] * 2 + [k2000] * 4, ordered=False
          )

      def test_maxdiff(self):
          names = ["Joe Smith %s" % i for i in range(20)]
          Person.objects.bulk_create([Person(name=name) for name in names])
          names.append("Extra Person")
          truncation_hint = "Set self.maxDiff to None to see it."

          def compare_names():
              # Fails by construction: "Extra Person" is not in the database.
              self.assertQuerySetEqual(
                  Person.objects.filter(name__startswith="Joe"),
                  names,
                  ordered=False,
                  transform=lambda p: p.name,
              )

          with self.assertRaises(AssertionError) as ctx:
              compare_names()
          self.assertIn(truncation_hint, str(ctx.exception))

          saved_max_diff = self.maxDiff
          self.maxDiff = None
          try:
              with self.assertRaises(AssertionError) as ctx:
                  compare_names()
          finally:
              self.maxDiff = saved_max_diff
          exception_msg = str(ctx.exception)
          self.assertNotIn(truncation_hint, exception_msg)
          for name in names:
              self.assertIn(name, exception_msg)
  @override_settings(ROOT_URLCONF="test_utils.urls")
  class CaptureQueriesContextManagerTests(TestCase):
      """Exercise CaptureQueriesContext directly and through the test client."""

      @classmethod
      def setUpTestData(cls):
          # Kept as a string so it can be searched for in the captured SQL.
          cls.person_pk = str(Person.objects.create(name="test").pk)

      def test_simple(self):
          with CaptureQueriesContext(connection) as queries:
              Person.objects.get(pk=self.person_pk)
          self.assertEqual(len(queries), 1)
          self.assertIn(self.person_pk, queries[0]["sql"])

          with CaptureQueriesContext(connection) as queries:
              pass
          self.assertEqual(0, len(queries))

      def test_within(self):
          # The captured queries are readable before the context exits.
          with CaptureQueriesContext(connection) as queries:
              Person.objects.get(pk=self.person_pk)
              self.assertEqual(len(queries), 1)
              self.assertIn(self.person_pk, queries[0]["sql"])

      def test_nested(self):
          with CaptureQueriesContext(connection) as outer_queries:
              Person.objects.count()
              with CaptureQueriesContext(connection) as inner_queries:
                  Person.objects.count()
          self.assertEqual(1, len(inner_queries))
          self.assertEqual(2, len(outer_queries))

      def test_failure(self):
          # An exception raised inside the context must propagate.
          with self.assertRaises(TypeError):
              with CaptureQueriesContext(connection):
                  raise TypeError

      def test_with_client(self):
          url = "/test_utils/get_person/%s/" % self.person_pk

          with CaptureQueriesContext(connection) as queries:
              self.client.get(url)
          self.assertEqual(len(queries), 1)
          self.assertIn(self.person_pk, queries[0]["sql"])

          with CaptureQueriesContext(connection) as queries:
              self.client.get(url)
          self.assertEqual(len(queries), 1)
          self.assertIn(self.person_pk, queries[0]["sql"])

          with CaptureQueriesContext(connection) as queries:
              self.client.get(url)
              self.client.get(url)
          self.assertEqual(len(queries), 2)
          self.assertIn(self.person_pk, queries[0]["sql"])
          self.assertIn(self.person_pk, queries[1]["sql"])
  @override_settings(ROOT_URLCONF="test_utils.urls")
  class AssertNumQueriesContextManagerTests(TestCase):
      """Exercise assertNumQueries() used as a context manager."""

      def test_simple(self):
          # Zero, one, then two count() queries, each in its own context.
          for expected in (0, 1, 2):
              with self.assertNumQueries(expected):
                  for _ in range(expected):
                      Person.objects.count()

      def test_failure(self):
          expected_msg = (
              "1 != 2 : 1 queries executed, 2 expected\nCaptured queries were:\n1."
          )
          with self.assertRaisesMessage(AssertionError, expected_msg):
              with self.assertNumQueries(2):
                  Person.objects.count()

          # An exception raised in the body wins over the query-count check.
          with self.assertRaises(TypeError):
              with self.assertNumQueries(4000):
                  raise TypeError

      def test_with_client(self):
          person = Person.objects.create(name="test")
          url = "/test_utils/get_person/%s/" % person.pk

          with self.assertNumQueries(1):
              self.client.get(url)

          with self.assertNumQueries(1):
              self.client.get(url)

          with self.assertNumQueries(2):
              self.client.get(url)
              self.client.get(url)
  @override_settings(ROOT_URLCONF="test_utils.urls")
  class AssertTemplateUsedContextManagerTests(SimpleTestCase):
      """
      Exercise assertTemplateUsed() / assertTemplateNotUsed(), mostly in their
      context-manager form.
      """

      def test_usage(self):
          base = "template_used/base.html"

          with self.assertTemplateUsed(base):
              render_to_string(base)

          with self.assertTemplateUsed(template_name=base):
              render_to_string(base)

          with self.assertTemplateUsed(base):
              render_to_string("template_used/include.html")

          with self.assertTemplateUsed(base):
              render_to_string("template_used/extends.html")

          with self.assertTemplateUsed(base):
              render_to_string(base)
              render_to_string(base)

      def test_nested_usage(self):
          base = "template_used/base.html"
          include = "template_used/include.html"
          extends = "template_used/extends.html"
          alternative = "template_used/alternative.html"

          with self.assertTemplateUsed(base):
              with self.assertTemplateUsed(include):
                  render_to_string(include)

          with self.assertTemplateUsed(extends):
              with self.assertTemplateUsed(base):
                  render_to_string(extends)

          with self.assertTemplateUsed(base):
              with self.assertTemplateUsed(alternative):
                  render_to_string(alternative)
              render_to_string(base)

          with self.assertTemplateUsed(base):
              render_to_string(extends)
              # The inner block covers only the alternative template.
              with self.assertTemplateNotUsed(base):
                  render_to_string(alternative)
              render_to_string(base)

      def test_not_used(self):
          for unused in ("template_used/base.html", "template_used/alternative.html"):
              with self.assertTemplateNotUsed(unused):
                  pass

      def test_error_message(self):
          base = "template_used/base.html"
          msg = "No templates used to render the response"
          with self.assertRaisesMessage(AssertionError, msg):
              with self.assertTemplateUsed(base):
                  pass

          with self.assertRaisesMessage(AssertionError, msg):
              with self.assertTemplateUsed(template_name=base):
                  pass

          msg2 = (
              "Template 'template_used/base.html' was not a template used to render "
              "the response. Actual template(s) used: template_used/alternative.html"
          )
          with self.assertRaisesMessage(AssertionError, msg2):
              with self.assertTemplateUsed(base):
                  render_to_string("template_used/alternative.html")

          # Same check through a response fetched with the test client.
          with self.assertRaisesMessage(
              AssertionError, "No templates used to render the response"
          ):
              response = self.client.get("/test_utils/no_template_used/")
              self.assertTemplateUsed(response, base)

      def test_msg_prefix(self):
          base = "template_used/base.html"
          msg_prefix = "Prefix"
          msg = f"{msg_prefix}: No templates used to render the response"
          with self.assertRaisesMessage(AssertionError, msg):
              with self.assertTemplateUsed(base, msg_prefix=msg_prefix):
                  pass

          with self.assertRaisesMessage(AssertionError, msg):
              with self.assertTemplateUsed(template_name=base, msg_prefix=msg_prefix):
                  pass

          msg = (
              f"{msg_prefix}: Template 'template_used/base.html' was not a "
              f"template used to render the response. Actual template(s) used: "
              f"template_used/alternative.html"
          )
          with self.assertRaisesMessage(AssertionError, msg):
              with self.assertTemplateUsed(base, msg_prefix=msg_prefix):
                  render_to_string("template_used/alternative.html")

      def test_count(self):
          base = "template_used/base.html"
          with self.assertTemplateUsed(base, count=2):
              render_to_string(base)
              render_to_string(base)

          msg = (
              "Template 'template_used/base.html' was expected to be rendered "
              "3 time(s) but was actually rendered 2 time(s)."
          )
          with self.assertRaisesMessage(AssertionError, msg):
              with self.assertTemplateUsed(base, count=3):
                  render_to_string(base)
                  render_to_string(base)

      def test_failure(self):
          base = "template_used/base.html"
          msg = "response and/or template_name argument must be provided"
          with self.assertRaisesMessage(TypeError, msg):
              with self.assertTemplateUsed():
                  pass

          msg = "No templates used to render the response"
          with self.assertRaisesMessage(AssertionError, msg):
              with self.assertTemplateUsed(""):
                  pass

          with self.assertRaisesMessage(AssertionError, msg):
              with self.assertTemplateUsed(""):
                  render_to_string(base)

          with self.assertRaisesMessage(AssertionError, msg):
              with self.assertTemplateUsed(template_name=""):
                  pass

          msg = (
              "Template 'template_used/base.html' was not a template used to "
              "render the response. Actual template(s) used: "
              "template_used/alternative.html"
          )
          with self.assertRaisesMessage(AssertionError, msg):
              with self.assertTemplateUsed(base):
                  render_to_string("template_used/alternative.html")

      def test_assert_used_on_http_response(self):
          # A bare HttpResponse was not fetched through the test client.
          response = HttpResponse()
          msg = "%s() is only usable on responses fetched using the Django test Client."
          used_msg = msg % "assertTemplateUsed"
          not_used_msg = msg % "assertTemplateNotUsed"
          with self.assertRaisesMessage(ValueError, used_msg):
              self.assertTemplateUsed(response, "template.html")

          with self.assertRaisesMessage(ValueError, not_used_msg):
              self.assertTemplateNotUsed(response, "template.html")
  513. class HTMLEqualTests(SimpleTestCase):
  def test_html_parser(self):
      # Nested elements are exposed through ``children``.
      wrapper = parse_html("<div><p>Hello</p></div>")
      self.assertEqual(len(wrapper.children), 1)
      paragraph = wrapper.children[0]
      self.assertEqual(paragraph.name, "p")
      self.assertEqual(paragraph.children[0], "Hello")

      # Unclosed tags and valueless attributes must parse without error.
      parse_html("<p>")
      parse_html("<p attr>")
      unclosed = parse_html("<p>foo")
      self.assertEqual(len(unclosed.children), 1)
      self.assertEqual(unclosed.name, "p")
      self.assertEqual(unclosed[0], "foo")
  def test_parse_html_in_script(self):
      # Markup-looking text inside <script> must parse without error.
      parse_html('<script>var a = "<p" + ">";</script>')
      parse_html(
          """
          <script>
          var js_sha_link='<p>***</p>';
          </script>
          """
      )

      # script content will be parsed to text
      script_dom = parse_html(
          """
          <script><p>foo</p> '</scr'+'ipt>' <span>bar</span></script>
          """
      )
      self.assertEqual(len(script_dom.children), 1)
      script_text = script_dom.children[0]
      self.assertEqual(script_text, "<p>foo</p> '</scr'+'ipt>' <span>bar</span>")
  def test_void_elements(self):
      # Each void element is checked in its bare and self-closing spelling.
      markup_forms = ("<p>Hello <%s> world</p>", "<p>Hello <%s /> world</p>")
      for tag in VOID_ELEMENTS:
          with self.subTest(tag):
              for markup in markup_forms:
                  dom = parse_html(markup % tag)
                  self.assertEqual(len(dom.children), 3)
                  self.assertEqual(dom[0], "Hello")
                  self.assertEqual(dom[1].name, tag)
                  self.assertEqual(dom[2], "world")
  def test_simple_equal_html(self):
      # Each pair must compare equal under assertHTMLEqual; checked in order.
      equal_pairs = [
          ("", ""),
          ("<p></p>", "<p></p>"),
          ("<p></p>", " <p> </p> "),
          ("<div><p>Hello</p></div>", "<div><p>Hello</p></div>"),
          ("<div><p>Hello</p></div>", "<div> <p>Hello</p> </div>"),
          ("<div>\n<p>Hello</p></div>", "<div><p>Hello</p></div>\n"),
          ("<div><p>Hello\nWorld !</p></div>", "<div><p>Hello World\n!</p></div>"),
          ("<div><p>Hello\nWorld !</p></div>", "<div><p>Hello World\n!</p></div>"),
          ("<p>Hello World !</p>", "<p>Hello World\n\n!</p>"),
          ("<p> </p>", "<p></p>"),
          ("<p/>", "<p></p>"),
          ("<p />", "<p></p>"),
          ("<input checked>", '<input checked="checked">'),
          ("<p>Hello", "<p> Hello"),
          ("<p>Hello</p>World", "<p>Hello</p> World"),
      ]
      for first, second in equal_pairs:
          self.assertHTMLEqual(first, second)
  def test_ignore_comments(self):
      # An HTML comment in one document must not make the comparison fail.
      with_comment = "<div>Hello<!-- this is a comment --> World!</div>"
      without_comment = "<div>Hello World!</div>"
      self.assertHTMLEqual(with_comment, without_comment)
  def test_unequal_html(self):
      # Each pair must compare unequal under assertHTMLNotEqual; checked in order.
      unequal_pairs = [
          ("<p>Hello</p>", "<p>Hello!</p>"),
          ("<p>foo&#20;bar</p>", "<p>foo&nbsp;bar</p>"),
          ("<p>foo bar</p>", "<p>foo &nbsp;bar</p>"),
          ("<p>foo nbsp</p>", "<p>foo &nbsp;</p>"),
          ("<p>foo #20</p>", "<p>foo &#20;</p>"),
          (
              "<p><span>Hello</span><span>World</span></p>",
              "<p><span>Hello</span>World</p>",
          ),
          (
              "<p><span>Hello</span>World</p>",
              "<p><span>Hello</span><span>World</span></p>",
          ),
      ]
      for first, second in unequal_pairs:
          self.assertHTMLNotEqual(first, second)
  def test_attributes(self):
      reference = '<input id="id_name" type="text" />'
      # Attribute order is not significant.
      self.assertHTMLEqual('<input type="text" id="id_name" />', reference)
      # Neither is the quoting style of attribute values.
      self.assertHTMLEqual("""<input type='text' id="id_name" />""", reference)
      # A differing attribute value is significant.
      self.assertHTMLNotEqual(
          '<input type="text" id="id_name" />',
          '<input type="password" id="id_name" />',
      )
  def test_class_attribute(self):
      # Each class-attribute spelling must equal the reference element.
      reference = '<p class="bar foo"></p>'
      variants = [
          '<p class="foo bar"></p>',
          '<p class=" foo bar "></p>',
          '<p class=" foo bar "></p>',
          '<p class="foo\tbar"></p>',
          '<p class="\tfoo\tbar\t"></p>',
          '<p class="\t\t\tfoo\t\t\tbar\t\t\t"></p>',
          '<p class="\t \nfoo \t\nbar\n\t "></p>',
      ]
      for variant in variants:
          with self.subTest(variant):
              self.assertHTMLEqual(variant, reference)
  def test_boolean_attribute(self):
      bare = "<input checked>"
      empty_value = '<input checked="">'
      named_value = '<input checked="checked">'
      # The three spellings are equal to one another...
      self.assertHTMLEqual(bare, empty_value)
      self.assertHTMLEqual(bare, named_value)
      self.assertHTMLEqual(empty_value, named_value)
      # ...but not to an arbitrary value.
      self.assertHTMLNotEqual(bare, '<input checked="invalid">')
      # All three serialize to the bare form.
      for markup in (bare, empty_value, named_value):
          self.assertEqual(str(parse_html(markup)), "<input checked>")
  def test_non_boolean_attibutes(self):
      bare = "<input value>"
      empty_value = '<input value="">'
      named_value = '<input value="value">'
      # A missing value equals an empty one, but not the attribute's own name.
      self.assertHTMLEqual(bare, empty_value)
      self.assertHTMLNotEqual(bare, named_value)
      # Both equal spellings serialize with an explicit empty value.
      for markup in (bare, empty_value):
          self.assertEqual(str(parse_html(markup)), '<input value="">')
  def test_normalize_refs(self):
      # Each pair spells the same character in two ways (named, decimal or
      # hexadecimal reference, or literal) and must compare equal.
      equivalent_refs = [
          ("&#39;", "&#x27;"),
          ("&#39;", "'"),
          ("&#x27;", "&#39;"),
          ("&#x27;", "'"),
          ("'", "&#39;"),
          ("'", "&#x27;"),
          ("&amp;", "&#38;"),
          ("&amp;", "&#x26;"),
          ("&amp;", "&"),
          ("&#38;", "&amp;"),
          ("&#38;", "&#x26;"),
          ("&#38;", "&"),
          ("&#x26;", "&amp;"),
          ("&#x26;", "&#38;"),
          ("&#x26;", "&"),
          ("&", "&amp;"),
          ("&", "&#38;"),
          ("&", "&#x26;"),
      ]
      for ref_pair in equivalent_refs:
          with self.subTest(repr(ref_pair)):
              first, second = ref_pair
              self.assertHTMLEqual(first, second)
  def test_complex_examples(self):
      # Same table rows, differing in line breaks and attribute order.
      compact_rows = """<tr><th><label for="id_first_name">First name:</label></th>
      <td><input type="text" name="first_name" value="John" id="id_first_name" /></td></tr>
      <tr><th><label for="id_last_name">Last name:</label></th>
      <td><input type="text" id="id_last_name" name="last_name" value="Lennon" /></td></tr>
      <tr><th><label for="id_birthday">Birthday:</label></th>
      <td><input type="text" value="1940-10-9" name="birthday" id="id_birthday" /></td></tr>"""  # NOQA
      spread_rows = """
      <tr><th>
      <label for="id_first_name">First name:</label></th><td>
      <input type="text" name="first_name" value="John" id="id_first_name" />
      </td></tr>
      <tr><th>
      <label for="id_last_name">Last name:</label></th><td>
      <input type="text" name="last_name" value="Lennon" id="id_last_name" />
      </td></tr>
      <tr><th>
      <label for="id_birthday">Birthday:</label></th><td>
      <input type="text" name="birthday" value="1940-10-9" id="id_birthday" />
      </td></tr>
      """
      self.assertHTMLEqual(compact_rows, spread_rows)

      # Same document with and without a doctype, HTML comments and an
      # explicit </p> placed after the <div>.
      implicit_p_close = """<!DOCTYPE html>
      <html>
      <head>
      <link rel="stylesheet">
      <title>Document</title>
      <meta attribute="value">
      </head>
      <body>
      <p>
      This is a valid paragraph
      <div> this is a div AFTER the p</div>
      </body>
      </html>"""
      explicit_p_close = """
      <html>
      <head>
      <link rel="stylesheet">
      <title>Document</title>
      <meta attribute="value">
      </head>
      <body>
      <p> This is a valid paragraph
      <!-- browsers would close the p tag here -->
      <div> this is a div AFTER the p</div>
      </p> <!-- this is invalid HTML parsing, but it should make no
      difference in most cases -->
      </body>
      </html>"""
      self.assertHTMLEqual(implicit_p_close, explicit_p_close)
  def test_html_contain(self):
      # equal html contains each other
      unclosed = parse_html("<p>foo")
      closed = parse_html("<p>foo</p>")
      self.assertIn(unclosed, closed)
      self.assertIn(closed, unclosed)

      wrapped = parse_html("<div><p>foo</p></div>")
      self.assertIn(unclosed, wrapped)
      self.assertNotIn(wrapped, unclosed)

      # A markup string is not matched, but plain text is.
      self.assertNotIn("<p>foo</p>", wrapped)
      self.assertIn("foo", wrapped)

      # when a root element is used ...
      siblings = parse_html("<p>foo</p><p>bar</p>")
      self.assertIn(parse_html("<p>foo</p><p>bar</p>"), siblings)
      self.assertIn(parse_html("<p>foo</p>"), siblings)
      self.assertIn(parse_html("<p>bar</p>"), siblings)

      wrapped_siblings = parse_html("<div><p>foo</p><p>bar</p></div>")
      self.assertIn(siblings, wrapped_siblings)
  def test_count(self):
      # equal html contains each other one time
      needle = parse_html("<p>foo")
      haystack = parse_html("<p>foo</p>")
      self.assertEqual(needle.count(haystack), 1)
      self.assertEqual(haystack.count(needle), 1)

      haystack = parse_html("<p>foo</p><p>bar</p>")
      self.assertEqual(haystack.count(needle), 1)

      # Counting plain strings.
      haystack = parse_html("<p>foo foo</p><p>foo</p>")
      self.assertEqual(haystack.count("foo"), 3)

      haystack = parse_html('<p class="bar">foo</p>')
      self.assertEqual(haystack.count("bar"), 0)
      self.assertEqual(haystack.count("class"), 0)
      self.assertEqual(haystack.count("p"), 0)
      self.assertEqual(haystack.count("o"), 2)

      # Counting the parsed "<p>foo" element in differently shaped documents.
      expected_counts = [
          ("<p>foo</p><p>foo</p>", 2),
          ('<div><p>foo<input type=""></p><p>foo</p></div>', 1),
          ("<div><div><p>foo</p></div></div>", 1),
          ("<p>foo<p>foo</p></p>", 1),
          ("<p>foo<p>bar</p></p>", 0),
      ]
      for markup, expected in expected_counts:
          haystack = parse_html(markup)
          self.assertEqual(haystack.count(needle), expected)

      # HTML with a root element contains the same HTML with no root element.
      needle = parse_html("<p>foo</p><p>bar</p>")
      haystack = parse_html("<div><p>foo</p><p>bar</p></div>")
      self.assertEqual(haystack.count(needle), 1)

      # Target of search is a sequence of child elements and appears more
      # than once.
      haystack = parse_html("<div><p>foo</p><p>bar</p><p>foo</p><p>bar</p></div>")
      self.assertEqual(haystack.count(needle), 2)

      # Searched HTML has additional children.
      needle = parse_html("<a/><b/>")
      haystack = parse_html("<a/><b/><c/>")
      self.assertEqual(haystack.count(needle), 1)

      # No match found in children.
      needle = parse_html("<b/><a/>")
      self.assertEqual(haystack.count(needle), 0)

      # Target of search found among children and grandchildren.
      needle = parse_html("<b/><b/>")
      haystack = parse_html("<a><b/><b/></a><b/><b/>")
      self.assertEqual(haystack.count(needle), 2)
  def test_root_element_escaped_html(self):
      """Escaped markup parsed at the root renders back to the same text."""
      escaped = "&lt;br&gt;"
      self.assertEqual(str(parse_html(escaped)), escaped)
  def test_parsing_errors(self):
      """Unbalanced or malformed HTML is reported instead of compared."""
      # An unclosed <p> never equals the empty document, in either order.
      for first, second in (("<p>", ""), ("", "<p>")):
          with self.assertRaises(AssertionError):
              self.assertHTMLEqual(first, second)

      # The failure names which argument was invalid and where.
      expected = (
          "First argument is not valid HTML:\n"
          "('Unexpected end tag `div` (Line 1, Column 6)', (1, 6))"
      )
      with self.assertRaisesMessage(AssertionError, expected):
          self.assertHTMLEqual("< div></ div>", "<div></div>")

      # parse_html() itself raises its own error type for a stray end tag.
      with self.assertRaises(HTMLParseError):
          parse_html("</p>")
  def test_escaped_html_errors(self):
      """A real tag and its escaped spelling are reported as different."""
      expected = "<p>\n<foo>\n</p> != <p>\n&lt;foo&gt;\n</p>\n"
      # Named and numeric character references produce the same message.
      for escaped in ("<p>&lt;foo&gt;</p>", "<p>&#60;foo&#62;</p>"):
          with self.assertRaisesMessage(AssertionError, expected):
              self.assertHTMLEqual("<p><foo></p>", escaped)
  def test_contains_html(self):
      """assertContains()/assertNotContains() with and without html=True."""
      body = (
          "<body>\n"
          'This is a form: <form method="get">\n'
          '<input type="text" name="Hello" />\n'
          "</form></body>"
      )
      response = HttpResponse(body)
      reordered_input = "<input name='Hello' type='text'>"
      form_open_tag = '<form method="get">'

      # Plain text matching: attribute order and quoting matter.
      self.assertNotContains(response, reordered_input)
      self.assertContains(response, form_open_tag)

      # HTML matching: the reordered <input> is found, while the lone
      # opening <form> tag is not.
      self.assertContains(response, reordered_input, html=True)
      self.assertNotContains(response, form_open_tag, html=True)

      invalid_response = HttpResponse("<body <bad>>")
      with self.assertRaises(AssertionError):
          self.assertContains(invalid_response, "<p></p>")

      with self.assertRaises(AssertionError):
          self.assertContains(response, '<p "whats" that>')
  def test_unicode_handling(self):
      """Non-ASCII text in a response is matched by an html=True lookup."""
      markup = (
          '<p class="help">Some help text for the title (with Unicode ŠĐĆŽćžšđ)</p>'
      )
      response = HttpResponse(markup)
      self.assertContains(response, markup, html=True)
  class JSONEqualTests(SimpleTestCase):
      """Tests for assertJSONEqual() and assertJSONNotEqual()."""

      def test_simple_equal(self):
          """Identical documents compare equal."""
          first = '{"attr1": "foo", "attr2":"baz"}'
          second = '{"attr1": "foo", "attr2":"baz"}'
          self.assertJSONEqual(first, second)

      def test_simple_equal_unordered(self):
          """Key order does not affect equality."""
          first = '{"attr1": "foo", "attr2":"baz"}'
          reordered = '{"attr2":"baz", "attr1": "foo"}'
          self.assertJSONEqual(first, reordered)

      def test_simple_equal_raise(self):
          """A missing key makes assertJSONEqual() fail."""
          full = '{"attr1": "foo", "attr2":"baz"}'
          partial = '{"attr2":"baz"}'
          with self.assertRaises(AssertionError):
              self.assertJSONEqual(full, partial)

      def test_equal_parsing_errors(self):
          """Malformed JSON in either position fails assertJSONEqual()."""
          broken = '{"attr1": "foo, "attr2":"baz"}'
          well_formed = '{"attr1": "foo", "attr2":"baz"}'
          for first, second in ((broken, well_formed), (well_formed, broken)):
              with self.assertRaises(AssertionError):
                  self.assertJSONEqual(first, second)

      def test_simple_not_equal(self):
          """Documents with different keys compare unequal."""
          full = '{"attr1": "foo", "attr2":"baz"}'
          partial = '{"attr2":"baz"}'
          self.assertJSONNotEqual(full, partial)

      def test_simple_not_equal_raise(self):
          """Identical documents make assertJSONNotEqual() fail."""
          first = '{"attr1": "foo", "attr2":"baz"}'
          second = '{"attr1": "foo", "attr2":"baz"}'
          with self.assertRaises(AssertionError):
              self.assertJSONNotEqual(first, second)

      def test_not_equal_parsing_errors(self):
          """Malformed JSON in either position fails assertJSONNotEqual()."""
          broken = '{"attr1": "foo, "attr2":"baz"}'
          well_formed = '{"attr1": "foo", "attr2":"baz"}'
          for first, second in ((broken, well_formed), (well_formed, broken)):
              with self.assertRaises(AssertionError):
                  self.assertJSONNotEqual(first, second)
  class XMLEqualTests(SimpleTestCase):
      """Tests for assertXMLEqual() and assertXMLNotEqual()."""

      def test_simple_equal(self):
          """Identical elements compare equal."""
          xml1 = "<elem attr1='a' attr2='b' />"
          xml2 = "<elem attr1='a' attr2='b' />"
          self.assertXMLEqual(xml1, xml2)

      def test_simple_equal_unordered(self):
          """Attribute order does not affect equality."""
          xml1 = "<elem attr1='a' attr2='b' />"
          xml2 = "<elem attr2='b' attr1='a' />"
          self.assertXMLEqual(xml1, xml2)

      def test_simple_equal_raise(self):
          """An extra attribute makes assertXMLEqual() fail."""
          xml1 = "<elem attr1='a' />"
          xml2 = "<elem attr2='b' attr1='a' />"
          with self.assertRaises(AssertionError):
              self.assertXMLEqual(xml1, xml2)

      def test_simple_equal_raises_message(self):
          """The failure message carries both reprs and a line diff."""
          xml1 = "<elem attr1='a' />"
          xml2 = "<elem attr2='b' attr1='a' />"
          # The "?" hint line is column-aligned under the inserted text
          # (" attr2='b'", ten characters starting five columns into the
          # line). It is built explicitly so the alignment whitespace, which
          # is significant for the substring match, cannot be lost.
          hint_line = "? " + " " * 5 + "+" * 10 + "\n"
          msg = (
              f"{xml1!r} != {xml2!r}\n"
              "- <elem attr1='a' />\n"
              "+ <elem attr2='b' attr1='a' />\n"
          ) + hint_line
          with self.assertRaisesMessage(AssertionError, msg):
              self.assertXMLEqual(xml1, xml2)

      def test_simple_not_equal(self):
          """Differing attribute values compare unequal."""
          xml1 = "<elem attr1='a' attr2='c' />"
          xml2 = "<elem attr1='a' attr2='b' />"
          self.assertXMLNotEqual(xml1, xml2)

      def test_simple_not_equal_raise(self):
          """Equivalent elements make assertXMLNotEqual() fail."""
          xml1 = "<elem attr1='a' attr2='b' />"
          xml2 = "<elem attr2='b' attr1='a' />"
          with self.assertRaises(AssertionError):
              self.assertXMLNotEqual(xml1, xml2)

      def test_parsing_errors(self):
          """Malformed XML fails the assertion rather than comparing unequal."""
          xml_invalid = "<elem attr1='a attr2='b' />"
          xml2 = "<elem attr2='b' attr1='a' />"
          with self.assertRaises(AssertionError):
              self.assertXMLNotEqual(xml_invalid, xml2)

      def test_comment_root(self):
          """Differing comments before the root element are ignored."""
          xml1 = "<?xml version='1.0'?><!-- comment1 --><elem attr1='a' attr2='b' />"
          xml2 = "<?xml version='1.0'?><!-- comment2 --><elem attr2='b' attr1='a' />"
          self.assertXMLEqual(xml1, xml2)

      def test_simple_equal_with_leading_or_trailing_whitespace(self):
          """Whitespace around the document does not affect equality."""
          xml1 = "<elem>foo</elem> \t\n"
          xml2 = " \t\n<elem>foo</elem>"
          self.assertXMLEqual(xml1, xml2)

      def test_simple_not_equal_with_whitespace_in_the_middle(self):
          """Whitespace between sibling elements is significant."""
          xml1 = "<elem>foo</elem><elem>bar</elem>"
          xml2 = "<elem>foo</elem> <elem>bar</elem>"
          self.assertXMLNotEqual(xml1, xml2)

      def test_doctype_root(self):
          """Differing DOCTYPE declarations are ignored."""
          xml1 = '<?xml version="1.0"?><!DOCTYPE root SYSTEM "example1.dtd"><root />'
          xml2 = '<?xml version="1.0"?><!DOCTYPE root SYSTEM "example2.dtd"><root />'
          self.assertXMLEqual(xml1, xml2)

      def test_processing_instruction(self):
          """Differing processing instructions are ignored."""
          xml1 = (
              '<?xml version="1.0"?>'
              '<?xml-model href="http://www.example1.com"?><root />'
          )
          xml2 = (
              '<?xml version="1.0"?>'
              '<?xml-model href="http://www.example2.com"?><root />'
          )
          self.assertXMLEqual(xml1, xml2)
          self.assertXMLEqual(
              '<?xml-stylesheet href="style1.xslt" type="text/xsl"?><root />',
              '<?xml-stylesheet href="style2.xslt" type="text/xsl"?><root />',
          )
  class SkippingExtraTests(TestCase):
      """A skipped test must not trigger loading of its fixtures."""

      fixtures = ["should_not_be_loaded.json"]

      # HACK: This depends on internals of our TestCase subclasses
      def __call__(self, result=None):
          """Run the test while asserting that no SQL query is issued."""
          # Loading a fixture would execute queries, so a zero-query run shows
          # that none was loaded.
          no_queries = self.assertNumQueries(0)
          with no_queries:
              super().__call__(result)

      @unittest.skip("Fixture loading should not be performed for skipped tests.")
      def test_fixtures_are_skipped(self):
          pass
  class AssertRaisesMsgTest(SimpleTestCase):
      """Tests for assertRaisesMessage()."""

      def test_assert_raises_message(self):
          """A wrong exception message fails in both calling conventions."""

          def raise_unexpected():
              raise ValueError("Unexpected message")

          failure = "'Expected message' not found in 'Unexpected message'"

          # Context manager form of assertRaisesMessage().
          with self.assertRaisesMessage(AssertionError, failure):
              with self.assertRaisesMessage(ValueError, "Expected message"):
                  raise ValueError("Unexpected message")

          # Callable form.
          with self.assertRaisesMessage(AssertionError, failure):
              self.assertRaisesMessage(
                  ValueError, "Expected message", raise_unexpected
              )

      def test_special_re_chars(self):
          """assertRaisesMessage shouldn't interpret RE special chars."""
          pattern_like = "[.*x+]y?"

          def raise_pattern_like():
              raise ValueError("[.*x+]y?")

          with self.assertRaisesMessage(ValueError, pattern_like):
              raise_pattern_like()
  class AssertWarnsMessageTests(SimpleTestCase):
      """Tests for assertWarnsMessage()."""

      def test_context_manager(self):
          """A warning carrying the expected message passes."""
          with self.assertWarnsMessage(UserWarning, "Expected message"):
              warnings.warn("Expected message", UserWarning)

      def test_context_manager_failure(self):
          """A warning with a different message fails with a clear error."""
          # Full failure text including the opening quote, the same format
          # that AssertRaisesMsgTest checks for assertRaisesMessage().
          msg = "'Expected message' not found in 'Unexpected message'"
          with self.assertRaisesMessage(AssertionError, msg):
              with self.assertWarnsMessage(UserWarning, "Expected message"):
                  warnings.warn("Unexpected message", UserWarning)

      def test_callable(self):
          """The callable form invokes the function and checks its warning."""

          def func():
              warnings.warn("Expected message", UserWarning)

          self.assertWarnsMessage(UserWarning, "Expected message", func)

      def test_special_re_chars(self):
          """assertWarnsMessage shouldn't interpret RE special chars."""

          def func1():
              warnings.warn("[.*x+]y?", UserWarning)

          with self.assertWarnsMessage(UserWarning, "[.*x+]y?"):
              func1()
  class AssertFieldOutputTests(SimpleTestCase):
      """Tests for assertFieldOutput()."""

      def test_assert_field_output(self):
          """Correct expectations pass; any wrong expectation fails."""
          error_invalid = ["Enter a valid email address."]
          good_valid = {"a@a.com": "a@a.com"}
          self.assertFieldOutput(EmailField, good_valid, {"aaa": error_invalid})

          # Each pair is (valid mapping, invalid mapping) and is wrong in
          # exactly one respect.
          failing_expectations = (
              # An extra, unexpected error message.
              ({"a@a.com": "a@a.com"}, {"aaa": error_invalid + ["Another error"]}),
              # A wrong cleaned value.
              ({"a@a.com": "Wrong output"}, {"aaa": error_invalid}),
              # A completely different error message.
              (
                  {"a@a.com": "a@a.com"},
                  {"aaa": ["Come on, gimme some well formatted data, dude."]},
              ),
          )
          for valid, invalid in failing_expectations:
              with self.assertRaises(AssertionError):
                  self.assertFieldOutput(EmailField, valid, invalid)

      def test_custom_required_message(self):
          """A field overriding the 'required' message is still accepted."""

          class MyCustomField(IntegerField):
              default_error_messages = {
                  "required": "This is really required.",
              }

          self.assertFieldOutput(MyCustomField, {}, {}, empty_value=None)
  @override_settings(ROOT_URLCONF="test_utils.urls")
  class AssertURLEqualTests(SimpleTestCase):
      """Tests for assertURLEqual()."""

      def test_equal(self):
          """URL pairs that must be treated as equal."""
          equal_pairs = (
              ("http://example.com/?", "http://example.com/"),
              ("http://example.com/?x=1&", "http://example.com/?x=1"),
              ("http://example.com/?x=1&y=2", "http://example.com/?y=2&x=1"),
              ("http://example.com/?x=1&y=2", "http://example.com/?y=2&x=1"),
              (
                  "http://example.com/?x=1&y=2&a=1&a=2",
                  "http://example.com/?a=1&a=2&y=2&x=1",
              ),
              ("/path/to/?x=1&y=2&z=3", "/path/to/?z=3&y=2&x=1"),
              ("?x=1&y=2&z=3", "?z=3&y=2&x=1"),
              ("/test_utils/no_template_used/", reverse_lazy("no_template_used")),
          )
          for first, second in equal_pairs:
              with self.subTest(url=first):
                  self.assertURLEqual(first, second)

      def test_not_equal(self):
          """URL pairs that must be treated as different."""
          unequal_pairs = (
              # Protocol must be the same.
              ("http://example.com/", "https://example.com/"),
              ("http://example.com/?x=1&x=2", "https://example.com/?x=2&x=1"),
              (
                  "http://example.com/?x=1&y=bar&x=2",
                  "https://example.com/?y=bar&x=2&x=1",
              ),
              # Parameters of the same name must be in the same order.
              ("/path/to?a=1&a=2", "/path/to/?a=2&a=1"),
          )
          for first, second in unequal_pairs:
              with self.subTest(url=first):
                  with self.assertRaises(AssertionError):
                      self.assertURLEqual(first, second)

      def test_message(self):
          """The failure message quotes both URLs."""
          first = "http://example.com/?x=1&x=2"
          second = "https://example.com/?x=2&x=1"
          expected = (
              "Expected 'http://example.com/?x=1&x=2' to equal "
              "'https://example.com/?x=2&x=1'"
          )
          with self.assertRaisesMessage(AssertionError, expected):
              self.assertURLEqual(first, second)

      def test_msg_prefix(self):
          """msg_prefix is prepended to the failure message."""
          first = "http://example.com/?x=1&x=2"
          second = "https://example.com/?x=2&x=1"
          expected = (
              "Prefix: Expected 'http://example.com/?x=1&x=2' to equal "
              "'https://example.com/?x=2&x=1'"
          )
          with self.assertRaisesMessage(AssertionError, expected):
              self.assertURLEqual(first, second, msg_prefix="Prefix: ")
  class TestForm(Form):
      """Single-field form whose validity is steered by the submitted value.

      The value "invalid" raises a field error, "invalid_non_field" raises a
      non-field error, and anything else passes these two checks.
      """

      field = CharField()

      def clean_field(self):
          """Reject the literal value "invalid" with a field-level error."""
          submitted = self.cleaned_data.get("field", "")
          if submitted != "invalid":
              return submitted
          raise ValidationError("invalid value")

      def clean(self):
          """Reject the literal value "invalid_non_field" with a form error."""
          triggers_form_error = self.cleaned_data.get("field") == "invalid_non_field"
          if triggers_form_error:
              raise ValidationError("non-field error")
          return self.cleaned_data

      @classmethod
      def _get_cleaned_form(cls, field_value):
          """Return a bound form for ``field_value`` after full_clean()."""
          bound_form = cls({"field": field_value})
          bound_form.full_clean()
          return bound_form

      @classmethod
      def valid(cls):
          """Return a cleaned form bound to the value "valid"."""
          return cls._get_cleaned_form("valid")

      @classmethod
      def invalid(cls, nonfield=False):
          """Return a cleaned form bound to one of the rejected values.

          With ``nonfield`` true the value that triggers the non-field error
          is used, otherwise the one that triggers the field error.
          """
          if nonfield:
              rejected_value = "invalid_non_field"
          else:
              rejected_value = "invalid"
          return cls._get_cleaned_form(rejected_value)
  class TestFormset(formset_factory(TestForm)):
      """Formset of TestForm with constructors for valid and invalid data."""

      @classmethod
      def _get_cleaned_formset(cls, field_value):
          """Return a one-form bound formset after full_clean()."""
          payload = {
              "form-TOTAL_FORMS": "1",
              "form-INITIAL_FORMS": "0",
              "form-0-field": field_value,
          }
          bound_formset = cls(payload)
          bound_formset.full_clean()
          return bound_formset

      @classmethod
      def valid(cls):
          """Return a cleaned formset whose single form holds "valid"."""
          return cls._get_cleaned_formset("valid")

      @classmethod
      def invalid(cls, nonfield=False, nonform=False):
          """Return a cleaned formset carrying one kind of error.

          ``nonform`` takes precedence: the formset is bound to empty data
          with the "missing_management_form" message overridden to "error".
          Otherwise the single form holds the value that triggers a non-field
          error (``nonfield`` true) or a field error.
          """
          if not nonform:
              rejected_value = "invalid_non_field" if nonfield else "invalid"
              return cls._get_cleaned_formset(rejected_value)
          broken_formset = cls({}, error_messages={"missing_management_form": "error"})
          broken_formset.full_clean()
          return broken_formset
  class AssertFormErrorTests(SimpleTestCase):
      """Tests for assertFormError(): passing cases and failure messages."""

      def test_single_error(self):
          """A single expected error may be given as a string."""
          form = TestForm.invalid()
          self.assertFormError(form, "field", "invalid value")

      def test_error_list(self):
          """Expected errors may be given as a list."""
          form = TestForm.invalid()
          self.assertFormError(form, "field", ["invalid value"])

      def test_empty_errors_valid_form(self):
          """An empty list matches a field without errors."""
          form = TestForm.valid()
          self.assertFormError(form, "field", [])

      def test_empty_errors_valid_form_non_field_errors(self):
          """An empty list matches the absence of non-field errors."""
          form = TestForm.valid()
          self.assertFormError(form, None, [])

      def test_field_not_in_form(self):
          """Naming a field the form lacks fails with a dedicated message."""
          expected = (
              "The form <TestForm bound=True, valid=False, fields=(field)> does not "
              "contain the field 'other_field'."
          )
          with self.assertRaisesMessage(AssertionError, expected):
              self.assertFormError(TestForm.invalid(), "other_field", "invalid value")

          prefix = "Custom prefix"
          with self.assertRaisesMessage(AssertionError, f"{prefix}: {expected}"):
              self.assertFormError(
                  TestForm.invalid(),
                  "other_field",
                  "invalid value",
                  msg_prefix=prefix,
              )

      def test_field_with_no_errors(self):
          """Expecting an error on a clean field fails and shows the diff."""
          expected = (
              "The errors of field 'field' on form <TestForm bound=True, valid=True, "
              "fields=(field)> don't match."
          )
          with self.assertRaisesMessage(AssertionError, expected) as caught:
              self.assertFormError(TestForm.valid(), "field", "invalid value")
          self.assertIn("[] != ['invalid value']", str(caught.exception))

          prefix = "Custom prefix"
          with self.assertRaisesMessage(AssertionError, f"{prefix}: {expected}"):
              self.assertFormError(
                  TestForm.valid(), "field", "invalid value", msg_prefix=prefix
              )

      def test_field_with_different_error(self):
          """Expecting the wrong error text fails and shows the diff."""
          expected = (
              "The errors of field 'field' on form <TestForm bound=True, valid=False, "
              "fields=(field)> don't match."
          )
          with self.assertRaisesMessage(AssertionError, expected) as caught:
              self.assertFormError(TestForm.invalid(), "field", "other error")
          self.assertIn("['invalid value'] != ['other error']", str(caught.exception))

          prefix = "Custom prefix"
          with self.assertRaisesMessage(AssertionError, f"{prefix}: {expected}"):
              self.assertFormError(
                  TestForm.invalid(), "field", "other error", msg_prefix=prefix
              )

      def test_unbound_form(self):
          """An unbound form is rejected outright."""
          expected = (
              "The form <TestForm bound=False, valid=Unknown, fields=(field)> is not "
              "bound, it will never have any errors."
          )
          with self.assertRaisesMessage(AssertionError, expected):
              self.assertFormError(TestForm(), "field", [])

          prefix = "Custom prefix"
          with self.assertRaisesMessage(AssertionError, f"{prefix}: {expected}"):
              self.assertFormError(TestForm(), "field", [], msg_prefix=prefix)

      def test_empty_errors_invalid_form(self):
          """Expecting no errors on a failing field fails and shows the diff."""
          expected = (
              "The errors of field 'field' on form <TestForm bound=True, valid=False, "
              "fields=(field)> don't match."
          )
          with self.assertRaisesMessage(AssertionError, expected) as caught:
              self.assertFormError(TestForm.invalid(), "field", [])
          self.assertIn("['invalid value'] != []", str(caught.exception))

      def test_non_field_errors(self):
          """field=None checks the form's non-field errors."""
          form = TestForm.invalid(nonfield=True)
          self.assertFormError(form, None, "non-field error")

      def test_different_non_field_errors(self):
          """A wrong non-field error fails and shows the diff."""
          expected = (
              "The non-field errors of form <TestForm bound=True, valid=False, "
              "fields=(field)> don't match."
          )
          with self.assertRaisesMessage(AssertionError, expected) as caught:
              self.assertFormError(
                  TestForm.invalid(nonfield=True), None, "other non-field error"
              )
          self.assertIn(
              "['non-field error'] != ['other non-field error']",
              str(caught.exception),
          )

          prefix = "Custom prefix"
          with self.assertRaisesMessage(AssertionError, f"{prefix}: {expected}"):
              self.assertFormError(
                  TestForm.invalid(nonfield=True),
                  None,
                  "other non-field error",
                  msg_prefix=prefix,
              )
  class AssertFormSetErrorTests(SimpleTestCase):
      """Tests for assertFormSetError(): passing cases and failure messages."""

      def test_single_error(self):
          """A single expected error may be given as a string."""
          formset = TestFormset.invalid()
          self.assertFormSetError(formset, 0, "field", "invalid value")

      def test_error_list(self):
          """Expected errors may be given as a list."""
          formset = TestFormset.invalid()
          self.assertFormSetError(formset, 0, "field", ["invalid value"])

      def test_empty_errors_valid_formset(self):
          """An empty list matches a field without errors."""
          formset = TestFormset.valid()
          self.assertFormSetError(formset, 0, "field", [])

      def test_multiple_forms(self):
          """Each form of the formset is addressed by its index."""
          payload = {
              "form-TOTAL_FORMS": "2",
              "form-INITIAL_FORMS": "0",
              "form-0-field": "valid",
              "form-1-field": "invalid",
          }
          formset = TestFormset(payload)
          formset.full_clean()
          self.assertFormSetError(formset, 0, "field", [])
          self.assertFormSetError(formset, 1, "field", ["invalid value"])

      def test_field_not_in_form(self):
          """Naming a field the form lacks fails with a dedicated message."""
          expected = (
              "The form 0 of formset <TestFormset: bound=True valid=False total_forms=1> "
              "does not contain the field 'other_field'."
          )
          with self.assertRaisesMessage(AssertionError, expected):
              self.assertFormSetError(
                  TestFormset.invalid(), 0, "other_field", "invalid value"
              )

          prefix = "Custom prefix"
          with self.assertRaisesMessage(AssertionError, f"{prefix}: {expected}"):
              self.assertFormSetError(
                  TestFormset.invalid(),
                  0,
                  "other_field",
                  "invalid value",
                  msg_prefix=prefix,
              )

      def test_field_with_no_errors(self):
          """Expecting an error on a clean field fails and shows the diff."""
          expected = (
              "The errors of field 'field' on form 0 of formset <TestFormset: bound=True "
              "valid=True total_forms=1> don't match."
          )
          with self.assertRaisesMessage(AssertionError, expected) as caught:
              self.assertFormSetError(TestFormset.valid(), 0, "field", "invalid value")
          self.assertIn("[] != ['invalid value']", str(caught.exception))

          prefix = "Custom prefix"
          with self.assertRaisesMessage(AssertionError, f"{prefix}: {expected}"):
              self.assertFormSetError(
                  TestFormset.valid(), 0, "field", "invalid value", msg_prefix=prefix
              )

      def test_field_with_different_error(self):
          """Expecting the wrong error text fails and shows the diff."""
          expected = (
              "The errors of field 'field' on form 0 of formset <TestFormset: bound=True "
              "valid=False total_forms=1> don't match."
          )
          with self.assertRaisesMessage(AssertionError, expected) as caught:
              self.assertFormSetError(TestFormset.invalid(), 0, "field", "other error")
          self.assertIn("['invalid value'] != ['other error']", str(caught.exception))

          prefix = "Custom prefix"
          with self.assertRaisesMessage(AssertionError, f"{prefix}: {expected}"):
              self.assertFormSetError(
                  TestFormset.invalid(), 0, "field", "other error", msg_prefix=prefix
              )

      def test_unbound_formset(self):
          """An unbound formset is rejected outright."""
          expected = (
              "The formset <TestFormset: bound=False valid=Unknown total_forms=1> is not "
              "bound, it will never have any errors."
          )
          with self.assertRaisesMessage(AssertionError, expected):
              self.assertFormSetError(TestFormset(), 0, "field", [])

      def test_empty_errors_invalid_formset(self):
          """Expecting no errors on a failing field fails and shows the diff."""
          expected = (
              "The errors of field 'field' on form 0 of formset <TestFormset: bound=True "
              "valid=False total_forms=1> don't match."
          )
          with self.assertRaisesMessage(AssertionError, expected) as caught:
              self.assertFormSetError(TestFormset.invalid(), 0, "field", [])
          self.assertIn("['invalid value'] != []", str(caught.exception))

      def test_non_field_errors(self):
          """field=None checks the non-field errors of the indexed form."""
          formset = TestFormset.invalid(nonfield=True)
          self.assertFormSetError(formset, 0, None, "non-field error")

      def test_different_non_field_errors(self):
          """A wrong non-field error fails and shows the diff."""
          expected = (
              "The non-field errors of form 0 of formset <TestFormset: bound=True "
              "valid=False total_forms=1> don't match."
          )
          with self.assertRaisesMessage(AssertionError, expected) as caught:
              self.assertFormSetError(
                  TestFormset.invalid(nonfield=True), 0, None, "other non-field error"
              )
          self.assertIn(
              "['non-field error'] != ['other non-field error']",
              str(caught.exception),
          )

          prefix = "Custom prefix"
          with self.assertRaisesMessage(AssertionError, f"{prefix}: {expected}"):
              self.assertFormSetError(
                  TestFormset.invalid(nonfield=True),
                  0,
                  None,
                  "other non-field error",
                  msg_prefix=prefix,
              )

      def test_no_non_field_errors(self):
          """Expecting a non-field error that is absent fails with the diff."""
          expected = (
              "The non-field errors of form 0 of formset <TestFormset: bound=True "
              "valid=False total_forms=1> don't match."
          )
          with self.assertRaisesMessage(AssertionError, expected) as caught:
              self.assertFormSetError(TestFormset.invalid(), 0, None, "non-field error")
          self.assertIn("[] != ['non-field error']", str(caught.exception))

          prefix = "Custom prefix"
          with self.assertRaisesMessage(AssertionError, f"{prefix}: {expected}"):
              self.assertFormSetError(
                  TestFormset.invalid(), 0, None, "non-field error", msg_prefix=prefix
              )

      def test_non_form_errors(self):
          """form_index=None and field=None check the non-form errors."""
          formset = TestFormset.invalid(nonform=True)
          self.assertFormSetError(formset, None, None, "error")

      def test_different_non_form_errors(self):
          """A wrong non-form error fails and shows the diff."""
          expected = (
              "The non-form errors of formset <TestFormset: bound=True valid=False "
              "total_forms=0> don't match."
          )
          with self.assertRaisesMessage(AssertionError, expected) as caught:
              self.assertFormSetError(
                  TestFormset.invalid(nonform=True), None, None, "other error"
              )
          self.assertIn("['error'] != ['other error']", str(caught.exception))

          prefix = "Custom prefix"
          with self.assertRaisesMessage(AssertionError, f"{prefix}: {expected}"):
              self.assertFormSetError(
                  TestFormset.invalid(nonform=True),
                  None,
                  None,
                  "other error",
                  msg_prefix=prefix,
              )

      def test_no_non_form_errors(self):
          """Expecting a non-form error that is absent fails with the diff."""
          expected = (
              "The non-form errors of formset <TestFormset: bound=True valid=False "
              "total_forms=1> don't match."
          )
          with self.assertRaisesMessage(AssertionError, expected) as caught:
              self.assertFormSetError(TestFormset.invalid(), None, None, "error")
          self.assertIn("[] != ['error']", str(caught.exception))

          prefix = "Custom prefix"
          with self.assertRaisesMessage(AssertionError, f"{prefix}: {expected}"):
              self.assertFormSetError(
                  TestFormset.invalid(),
                  None,
                  None,
                  "error",
                  msg_prefix=prefix,
              )

      def test_non_form_errors_with_field(self):
          """A field name combined with form_index=None is a usage error."""
          expected = "You must use field=None with form_index=None."
          with self.assertRaisesMessage(ValueError, expected):
              self.assertFormSetError(
                  TestFormset.invalid(nonform=True), None, "field", "error"
              )

      def test_form_index_too_big(self):
          """An out-of-range index is reported using the singular "form"."""
          expected = (
              "The formset <TestFormset: bound=True valid=False total_forms=1> only has "
              "1 form."
          )
          with self.assertRaisesMessage(AssertionError, expected):
              self.assertFormSetError(TestFormset.invalid(), 2, "field", "error")

      def test_form_index_too_big_plural(self):
          """An out-of-range index is reported using the plural "forms"."""
          payload = {
              "form-TOTAL_FORMS": "2",
              "form-INITIAL_FORMS": "0",
              "form-0-field": "valid",
              "form-1-field": "valid",
          }
          formset = TestFormset(payload)
          formset.full_clean()
          expected = (
              "The formset <TestFormset: bound=True valid=True total_forms=2> only has 2 "
              "forms."
          )
          with self.assertRaisesMessage(AssertionError, expected):
              self.assertFormSetError(formset, 2, "field", "error")
  class FirstUrls:
      """URLconf stand-in exposing only the route named "first"."""

      urlpatterns = [
          path("first/", empty_response, name="first"),
      ]
  class SecondUrls:
      """URLconf stand-in exposing only the route named "second"."""

      urlpatterns = [
          path("second/", empty_response, name="second"),
      ]
  class SetupTestEnvironmentTests(SimpleTestCase):
      """Tests for setup_test_environment()."""

      def test_setup_test_environment_calling_more_than_once(self):
          """Calling it again here raises RuntimeError."""
          expected = "setup_test_environment() was already called"
          with self.assertRaisesMessage(RuntimeError, expected):
              setup_test_environment()

      def test_allowed_hosts(self):
          """ALLOWED_HOSTS becomes a list with "testserver" appended."""
          for type_ in (list, tuple):
              with self.subTest(type_=type_):
                  allowed_hosts = type_("*")
                  with mock.patch("django.test.utils._TestState") as test_state:
                      # Remove the mock's saved_data attribute before calling
                      # setup_test_environment() again.
                      del test_state.saved_data
                      with self.settings(ALLOWED_HOSTS=allowed_hosts):
                          setup_test_environment()
                          self.assertEqual(
                              settings.ALLOWED_HOSTS, ["*", "testserver"]
                          )
  class OverrideSettingsTests(SimpleTestCase):
      """Overridden settings must be visible to the objects derived from them."""

      # #21518 -- If neither override_settings nor a setting_changed receiver
      # clears the URL cache between tests, then one of test_urlconf_first or
      # test_urlconf_second will fail.
      @override_settings(ROOT_URLCONF=FirstUrls)
      def test_urlconf_first(self):
          reverse("first")

      @override_settings(ROOT_URLCONF=SecondUrls)
      def test_urlconf_second(self):
          reverse("second")

      def test_urlconf_cache(self):
          """Nested ROOT_URLCONF overrides take effect and are undone in turn."""

          def assert_not_reversible(name):
              with self.assertRaises(NoReverseMatch):
                  reverse(name)

          assert_not_reversible("first")
          assert_not_reversible("second")

          with override_settings(ROOT_URLCONF=FirstUrls):
              self.client.get(reverse("first"))
              assert_not_reversible("second")

              with override_settings(ROOT_URLCONF=SecondUrls):
                  assert_not_reversible("first")
                  self.client.get(reverse("second"))

              # Back to FirstUrls after leaving the inner override.
              self.client.get(reverse("first"))
              assert_not_reversible("second")

          # Back to the original URLconf: neither name resolves.
          assert_not_reversible("first")
          assert_not_reversible("second")

      def test_override_media_root(self):
          """
          Overriding the MEDIA_ROOT setting should be reflected in the
          base_location attribute of django.core.files.storage.default_storage.
          """
          self.assertEqual(default_storage.base_location, "")
          overridden = "test_value"
          with self.settings(MEDIA_ROOT=overridden):
              self.assertEqual(default_storage.base_location, overridden)

      def test_override_media_url(self):
          """
          Overriding the MEDIA_URL setting should be reflected in the
          base_url attribute of django.core.files.storage.default_storage.
          """
          # NOTE(review): the precondition checks base_location, not base_url.
          # Kept as-is; confirm whether that is intentional.
          self.assertEqual(default_storage.base_location, "")
          overridden = "/test_value/"
          with self.settings(MEDIA_URL=overridden):
              self.assertEqual(default_storage.base_url, overridden)

      def test_override_file_upload_permissions(self):
          """
          Overriding the FILE_UPLOAD_PERMISSIONS setting should be reflected in
          the file_permissions_mode attribute of
          django.core.files.storage.default_storage.
          """
          self.assertEqual(default_storage.file_permissions_mode, 0o644)
          overridden = 0o777
          with self.settings(FILE_UPLOAD_PERMISSIONS=overridden):
              self.assertEqual(default_storage.file_permissions_mode, overridden)

      def test_override_file_upload_directory_permissions(self):
          """
          Overriding the FILE_UPLOAD_DIRECTORY_PERMISSIONS setting should be
          reflected in the directory_permissions_mode attribute of
          django.core.files.storage.default_storage.
          """
          self.assertIsNone(default_storage.directory_permissions_mode)
          overridden = 0o777
          with self.settings(FILE_UPLOAD_DIRECTORY_PERMISSIONS=overridden):
              self.assertEqual(default_storage.directory_permissions_mode, overridden)

      def test_override_database_routers(self):
          """
          Overriding DATABASE_ROUTERS should update the base router.
          """
          replacement_routers = [object()]
          with self.settings(DATABASE_ROUTERS=replacement_routers):
              self.assertEqual(router.routers, replacement_routers)

      def test_override_static_url(self):
          """
          Overriding the STATIC_URL setting should be reflected in the
          base_url attribute of
          django.contrib.staticfiles.storage.staticfiles_storage.
          """
          overridden = "/test/"
          with self.settings(STATIC_URL=overridden):
              self.assertEqual(staticfiles_storage.base_url, overridden)

      def test_override_static_root(self):
          """
          Overriding the STATIC_ROOT setting should be reflected in the
          location attribute of
          django.contrib.staticfiles.storage.staticfiles_storage.
          """
          overridden = "/tmp/test"
          with self.settings(STATIC_ROOT=overridden):
              self.assertEqual(
                  staticfiles_storage.location, os.path.abspath(overridden)
              )

      def test_override_staticfiles_storage(self):
          """
          Overriding the STORAGES setting should be reflected in
          the value of django.contrib.staticfiles.storage.staticfiles_storage.
          """
          backend_name = "ManifestStaticFilesStorage"
          backend_path = f"django.contrib.staticfiles.storage.{backend_name}"
          storages = {STATICFILES_STORAGE_ALIAS: {"BACKEND": backend_path}}
          with self.settings(STORAGES=storages):
              self.assertEqual(staticfiles_storage.__class__.__name__, backend_name)

      def test_override_staticfiles_finders(self):
          """
          Overriding the STATICFILES_FINDERS setting should be reflected in
          the return value of django.contrib.staticfiles.finders.get_finders.
          """
          default_finders = list(get_finders())
          self.assertGreater(len(default_finders), 1)
          finders = ["django.contrib.staticfiles.finders.FileSystemFinder"]
          with self.settings(STATICFILES_FINDERS=finders):
              active_finders = list(get_finders())
              self.assertEqual(len(active_finders), len(finders))

      def test_override_staticfiles_dirs(self):
          """
          Overriding the STATICFILES_DIRS setting should be reflected in
          the locations attribute of the
          django.contrib.staticfiles.finders.FileSystemFinder instance.
          """
          finder_path = "django.contrib.staticfiles.finders.FileSystemFinder"
          test_path = "/tmp/test"
          expected_location = ("", test_path)

          self.assertNotIn(expected_location, get_finder(finder_path).locations)
          with self.settings(STATICFILES_DIRS=[test_path]):
              # Fetch the finder again so the overridden setting is seen.
              self.assertIn(expected_location, get_finder(finder_path).locations)
  1517. @skipUnlessDBFeature("supports_transactions")
  1518. class TestBadSetUpTestData(TestCase):
  1519. """
  1520. An exception in setUpTestData() shouldn't leak a transaction which would
  1521. cascade across the rest of the test suite.
  1522. """
  1523. class MyException(Exception):
  1524. pass
  1525. @classmethod
  1526. def setUpClass(cls):
  1527. try:
  1528. super().setUpClass()
  1529. except cls.MyException:
  1530. cls._in_atomic_block = connection.in_atomic_block
  1531. @classmethod
  1532. def tearDownClass(Cls):
  1533. # override to avoid a second cls._rollback_atomics() which would fail.
  1534. # Normal setUpClass() methods won't have exception handling so this
  1535. # method wouldn't typically be run.
  1536. pass
  1537. @classmethod
  1538. def setUpTestData(cls):
  1539. # Simulate a broken setUpTestData() method.
  1540. raise cls.MyException()
  1541. def test_failure_in_setUpTestData_should_rollback_transaction(self):
  1542. # setUpTestData() should call _rollback_atomics() so that the
  1543. # transaction doesn't leak.
  1544. self.assertFalse(self._in_atomic_block)
  1545. @skipUnlessDBFeature("supports_transactions")
  1546. class CaptureOnCommitCallbacksTests(TestCase):
  1547. databases = {"default", "other"}
  1548. callback_called = False
  1549. def enqueue_callback(self, using="default"):
  1550. def hook():
  1551. self.callback_called = True
  1552. transaction.on_commit(hook, using=using)
  1553. def test_no_arguments(self):
  1554. with self.captureOnCommitCallbacks() as callbacks:
  1555. self.enqueue_callback()
  1556. self.assertEqual(len(callbacks), 1)
  1557. self.assertIs(self.callback_called, False)
  1558. callbacks[0]()
  1559. self.assertIs(self.callback_called, True)
  1560. def test_using(self):
  1561. with self.captureOnCommitCallbacks(using="other") as callbacks:
  1562. self.enqueue_callback(using="other")
  1563. self.assertEqual(len(callbacks), 1)
  1564. self.assertIs(self.callback_called, False)
  1565. callbacks[0]()
  1566. self.assertIs(self.callback_called, True)
  1567. def test_different_using(self):
  1568. with self.captureOnCommitCallbacks(using="default") as callbacks:
  1569. self.enqueue_callback(using="other")
  1570. self.assertEqual(callbacks, [])
  1571. def test_execute(self):
  1572. with self.captureOnCommitCallbacks(execute=True) as callbacks:
  1573. self.enqueue_callback()
  1574. self.assertEqual(len(callbacks), 1)
  1575. self.assertIs(self.callback_called, True)
  1576. def test_pre_callback(self):
  1577. def pre_hook():
  1578. pass
  1579. transaction.on_commit(pre_hook, using="default")
  1580. with self.captureOnCommitCallbacks() as callbacks:
  1581. self.enqueue_callback()
  1582. self.assertEqual(len(callbacks), 1)
  1583. self.assertNotEqual(callbacks[0], pre_hook)
  1584. def test_with_rolled_back_savepoint(self):
  1585. with self.captureOnCommitCallbacks() as callbacks:
  1586. try:
  1587. with transaction.atomic():
  1588. self.enqueue_callback()
  1589. raise IntegrityError
  1590. except IntegrityError:
  1591. # Inner transaction.atomic() has been rolled back.
  1592. pass
  1593. self.assertEqual(callbacks, [])
  1594. def test_execute_recursive(self):
  1595. with self.captureOnCommitCallbacks(execute=True) as callbacks:
  1596. transaction.on_commit(self.enqueue_callback)
  1597. self.assertEqual(len(callbacks), 2)
  1598. self.assertIs(self.callback_called, True)
  1599. def test_execute_tree(self):
  1600. """
  1601. A visualisation of the callback tree tested. Each node is expected to
  1602. be visited only once:
  1603. └─branch_1
  1604. ├─branch_2
  1605. │ ├─leaf_1
  1606. │ └─leaf_2
  1607. └─leaf_3
  1608. """
  1609. branch_1_call_counter = 0
  1610. branch_2_call_counter = 0
  1611. leaf_1_call_counter = 0
  1612. leaf_2_call_counter = 0
  1613. leaf_3_call_counter = 0
  1614. def leaf_1():
  1615. nonlocal leaf_1_call_counter
  1616. leaf_1_call_counter += 1
  1617. def leaf_2():
  1618. nonlocal leaf_2_call_counter
  1619. leaf_2_call_counter += 1
  1620. def leaf_3():
  1621. nonlocal leaf_3_call_counter
  1622. leaf_3_call_counter += 1
  1623. def branch_1():
  1624. nonlocal branch_1_call_counter
  1625. branch_1_call_counter += 1
  1626. transaction.on_commit(branch_2)
  1627. transaction.on_commit(leaf_3)
  1628. def branch_2():
  1629. nonlocal branch_2_call_counter
  1630. branch_2_call_counter += 1
  1631. transaction.on_commit(leaf_1)
  1632. transaction.on_commit(leaf_2)
  1633. with self.captureOnCommitCallbacks(execute=True) as callbacks:
  1634. transaction.on_commit(branch_1)
  1635. self.assertEqual(branch_1_call_counter, 1)
  1636. self.assertEqual(branch_2_call_counter, 1)
  1637. self.assertEqual(leaf_1_call_counter, 1)
  1638. self.assertEqual(leaf_2_call_counter, 1)
  1639. self.assertEqual(leaf_3_call_counter, 1)
  1640. self.assertEqual(callbacks, [branch_1, branch_2, leaf_3, leaf_1, leaf_2])
  1641. def test_execute_robust(self):
  1642. class MyException(Exception):
  1643. pass
  1644. def hook():
  1645. self.callback_called = True
  1646. raise MyException("robust callback")
  1647. with self.assertLogs("django.test", "ERROR") as cm:
  1648. with self.captureOnCommitCallbacks(execute=True) as callbacks:
  1649. transaction.on_commit(hook, robust=True)
  1650. self.assertEqual(len(callbacks), 1)
  1651. self.assertIs(self.callback_called, True)
  1652. log_record = cm.records[0]
  1653. self.assertEqual(
  1654. log_record.getMessage(),
  1655. "Error calling CaptureOnCommitCallbacksTests.test_execute_robust.<locals>."
  1656. "hook in on_commit() (robust callback).",
  1657. )
  1658. self.assertIsNotNone(log_record.exc_info)
  1659. raised_exception = log_record.exc_info[1]
  1660. self.assertIsInstance(raised_exception, MyException)
  1661. self.assertEqual(str(raised_exception), "robust callback")
  1662. class DisallowedDatabaseQueriesTests(SimpleTestCase):
  1663. def test_disallowed_database_connections(self):
  1664. expected_message = (
  1665. "Database connections to 'default' are not allowed in SimpleTestCase "
  1666. "subclasses. Either subclass TestCase or TransactionTestCase to "
  1667. "ensure proper test isolation or add 'default' to "
  1668. "test_utils.tests.DisallowedDatabaseQueriesTests.databases to "
  1669. "silence this failure."
  1670. )
  1671. with self.assertRaisesMessage(DatabaseOperationForbidden, expected_message):
  1672. connection.connect()
  1673. with self.assertRaisesMessage(DatabaseOperationForbidden, expected_message):
  1674. connection.temporary_connection()
  1675. def test_disallowed_database_queries(self):
  1676. expected_message = (
  1677. "Database queries to 'default' are not allowed in SimpleTestCase "
  1678. "subclasses. Either subclass TestCase or TransactionTestCase to "
  1679. "ensure proper test isolation or add 'default' to "
  1680. "test_utils.tests.DisallowedDatabaseQueriesTests.databases to "
  1681. "silence this failure."
  1682. )
  1683. with self.assertRaisesMessage(DatabaseOperationForbidden, expected_message):
  1684. Car.objects.first()
  1685. def test_disallowed_database_chunked_cursor_queries(self):
  1686. expected_message = (
  1687. "Database queries to 'default' are not allowed in SimpleTestCase "
  1688. "subclasses. Either subclass TestCase or TransactionTestCase to "
  1689. "ensure proper test isolation or add 'default' to "
  1690. "test_utils.tests.DisallowedDatabaseQueriesTests.databases to "
  1691. "silence this failure."
  1692. )
  1693. with self.assertRaisesMessage(DatabaseOperationForbidden, expected_message):
  1694. next(Car.objects.iterator())
  1695. class AllowedDatabaseQueriesTests(SimpleTestCase):
  1696. databases = {"default"}
  1697. def test_allowed_database_queries(self):
  1698. Car.objects.first()
  1699. def test_allowed_database_chunked_cursor_queries(self):
  1700. next(Car.objects.iterator(), None)
  1701. class DatabaseAliasTests(SimpleTestCase):
  1702. def setUp(self):
  1703. self.addCleanup(setattr, self.__class__, "databases", self.databases)
  1704. def test_no_close_match(self):
  1705. self.__class__.databases = {"void"}
  1706. message = (
  1707. "test_utils.tests.DatabaseAliasTests.databases refers to 'void' which is "
  1708. "not defined in settings.DATABASES."
  1709. )
  1710. with self.assertRaisesMessage(ImproperlyConfigured, message):
  1711. self._validate_databases()
  1712. def test_close_match(self):
  1713. self.__class__.databases = {"defualt"}
  1714. message = (
  1715. "test_utils.tests.DatabaseAliasTests.databases refers to 'defualt' which "
  1716. "is not defined in settings.DATABASES. Did you mean 'default'?"
  1717. )
  1718. with self.assertRaisesMessage(ImproperlyConfigured, message):
  1719. self._validate_databases()
  1720. def test_match(self):
  1721. self.__class__.databases = {"default", "other"}
  1722. self.assertEqual(self._validate_databases(), frozenset({"default", "other"}))
  1723. def test_all(self):
  1724. self.__class__.databases = "__all__"
  1725. self.assertEqual(self._validate_databases(), frozenset(connections))
  1726. @isolate_apps("test_utils", attr_name="class_apps")
  1727. class IsolatedAppsTests(SimpleTestCase):
  1728. def test_installed_apps(self):
  1729. self.assertEqual(
  1730. [app_config.label for app_config in self.class_apps.get_app_configs()],
  1731. ["test_utils"],
  1732. )
  1733. def test_class_decoration(self):
  1734. class ClassDecoration(models.Model):
  1735. pass
  1736. self.assertEqual(ClassDecoration._meta.apps, self.class_apps)
  1737. @isolate_apps("test_utils", kwarg_name="method_apps")
  1738. def test_method_decoration(self, method_apps):
  1739. class MethodDecoration(models.Model):
  1740. pass
  1741. self.assertEqual(MethodDecoration._meta.apps, method_apps)
  1742. def test_context_manager(self):
  1743. with isolate_apps("test_utils") as context_apps:
  1744. class ContextManager(models.Model):
  1745. pass
  1746. self.assertEqual(ContextManager._meta.apps, context_apps)
  1747. @isolate_apps("test_utils", kwarg_name="method_apps")
  1748. def test_nested(self, method_apps):
  1749. class MethodDecoration(models.Model):
  1750. pass
  1751. with isolate_apps("test_utils") as context_apps:
  1752. class ContextManager(models.Model):
  1753. pass
  1754. with isolate_apps("test_utils") as nested_context_apps:
  1755. class NestedContextManager(models.Model):
  1756. pass
  1757. self.assertEqual(MethodDecoration._meta.apps, method_apps)
  1758. self.assertEqual(ContextManager._meta.apps, context_apps)
  1759. self.assertEqual(NestedContextManager._meta.apps, nested_context_apps)
  1760. class DoNothingDecorator(TestContextDecorator):
  1761. def enable(self):
  1762. pass
  1763. def disable(self):
  1764. pass
  1765. class TestContextDecoratorTests(SimpleTestCase):
  1766. @mock.patch.object(DoNothingDecorator, "disable")
  1767. def test_exception_in_setup(self, mock_disable):
  1768. """An exception is setUp() is reraised after disable() is called."""
  1769. class ExceptionInSetUp(unittest.TestCase):
  1770. def setUp(self):
  1771. raise NotImplementedError("reraised")
  1772. decorator = DoNothingDecorator()
  1773. decorated_test_class = decorator.__call__(ExceptionInSetUp)()
  1774. self.assertFalse(mock_disable.called)
  1775. with self.assertRaisesMessage(NotImplementedError, "reraised"):
  1776. decorated_test_class.setUp()
  1777. decorated_test_class.doCleanups()
  1778. self.assertTrue(mock_disable.called)
  1779. def test_cleanups_run_after_tearDown(self):
  1780. calls = []
  1781. class SaveCallsDecorator(TestContextDecorator):
  1782. def enable(self):
  1783. calls.append("enable")
  1784. def disable(self):
  1785. calls.append("disable")
  1786. class AddCleanupInSetUp(unittest.TestCase):
  1787. def setUp(self):
  1788. calls.append("setUp")
  1789. self.addCleanup(lambda: calls.append("cleanup"))
  1790. decorator = SaveCallsDecorator()
  1791. decorated_test_class = decorator.__call__(AddCleanupInSetUp)()
  1792. decorated_test_class.setUp()
  1793. decorated_test_class.tearDown()
  1794. decorated_test_class.doCleanups()
  1795. self.assertEqual(calls, ["enable", "setUp", "cleanup", "disable"])