tests.py 82 KB


  1. import os
  2. import sys
  3. import unittest
  4. import warnings
  5. from io import StringIO
  6. from unittest import mock
  7. from django.conf import STATICFILES_STORAGE_ALIAS, settings
  8. from django.contrib.staticfiles.finders import get_finder, get_finders
  9. from django.contrib.staticfiles.storage import staticfiles_storage
  10. from django.core.exceptions import ImproperlyConfigured
  11. from django.core.files.storage import default_storage
  12. from django.db import (
  13. IntegrityError,
  14. connection,
  15. connections,
  16. models,
  17. router,
  18. transaction,
  19. )
  20. from django.forms import (
  21. CharField,
  22. EmailField,
  23. Form,
  24. IntegerField,
  25. ValidationError,
  26. formset_factory,
  27. )
  28. from django.http import HttpResponse
  29. from django.template import Context, Template
  30. from django.template.loader import render_to_string
  31. from django.test import (
  32. SimpleTestCase,
  33. TestCase,
  34. TransactionTestCase,
  35. skipIfDBFeature,
  36. skipUnlessDBFeature,
  37. )
  38. from django.test.html import HTMLParseError, parse_html
  39. from django.test.testcases import DatabaseOperationForbidden
  40. from django.test.utils import (
  41. CaptureQueriesContext,
  42. TestContextDecorator,
  43. isolate_apps,
  44. override_settings,
  45. setup_test_environment,
  46. )
  47. from django.urls import NoReverseMatch, path, reverse, reverse_lazy
  48. from django.utils.html import VOID_ELEMENTS
  49. from django.utils.version import PY311
  50. from .models import Car, Person, PossessedCar
  51. from .views import empty_response
  52. class SkippingTestCase(SimpleTestCase):
  53. def _assert_skipping(self, func, expected_exc, msg=None):
  54. try:
  55. if msg is not None:
  56. with self.assertRaisesMessage(expected_exc, msg):
  57. func()
  58. else:
  59. with self.assertRaises(expected_exc):
  60. func()
  61. except unittest.SkipTest:
  62. self.fail("%s should not result in a skipped test." % func.__name__)
  63. def test_skip_unless_db_feature(self):
  64. """
  65. Testing the django.test.skipUnlessDBFeature decorator.
  66. """
  67. # Total hack, but it works, just want an attribute that's always true.
  68. @skipUnlessDBFeature("__class__")
  69. def test_func():
  70. raise ValueError
  71. @skipUnlessDBFeature("notprovided")
  72. def test_func2():
  73. raise ValueError
  74. @skipUnlessDBFeature("__class__", "__class__")
  75. def test_func3():
  76. raise ValueError
  77. @skipUnlessDBFeature("__class__", "notprovided")
  78. def test_func4():
  79. raise ValueError
  80. self._assert_skipping(test_func, ValueError)
  81. self._assert_skipping(test_func2, unittest.SkipTest)
  82. self._assert_skipping(test_func3, ValueError)
  83. self._assert_skipping(test_func4, unittest.SkipTest)
  84. class SkipTestCase(SimpleTestCase):
  85. @skipUnlessDBFeature("missing")
  86. def test_foo(self):
  87. pass
  88. self._assert_skipping(
  89. SkipTestCase("test_foo").test_foo,
  90. ValueError,
  91. "skipUnlessDBFeature cannot be used on test_foo (test_utils.tests."
  92. "SkippingTestCase.test_skip_unless_db_feature.<locals>.SkipTestCase%s) "
  93. "as SkippingTestCase.test_skip_unless_db_feature.<locals>.SkipTestCase "
  94. "doesn't allow queries against the 'default' database."
  95. # Python 3.11 uses fully qualified test name in the output.
  96. % (".test_foo" if PY311 else ""),
  97. )
  98. def test_skip_if_db_feature(self):
  99. """
  100. Testing the django.test.skipIfDBFeature decorator.
  101. """
  102. @skipIfDBFeature("__class__")
  103. def test_func():
  104. raise ValueError
  105. @skipIfDBFeature("notprovided")
  106. def test_func2():
  107. raise ValueError
  108. @skipIfDBFeature("__class__", "__class__")
  109. def test_func3():
  110. raise ValueError
  111. @skipIfDBFeature("__class__", "notprovided")
  112. def test_func4():
  113. raise ValueError
  114. @skipIfDBFeature("notprovided", "notprovided")
  115. def test_func5():
  116. raise ValueError
  117. self._assert_skipping(test_func, unittest.SkipTest)
  118. self._assert_skipping(test_func2, ValueError)
  119. self._assert_skipping(test_func3, unittest.SkipTest)
  120. self._assert_skipping(test_func4, unittest.SkipTest)
  121. self._assert_skipping(test_func5, ValueError)
  122. class SkipTestCase(SimpleTestCase):
  123. @skipIfDBFeature("missing")
  124. def test_foo(self):
  125. pass
  126. self._assert_skipping(
  127. SkipTestCase("test_foo").test_foo,
  128. ValueError,
  129. "skipIfDBFeature cannot be used on test_foo (test_utils.tests."
  130. "SkippingTestCase.test_skip_if_db_feature.<locals>.SkipTestCase%s) "
  131. "as SkippingTestCase.test_skip_if_db_feature.<locals>.SkipTestCase "
  132. "doesn't allow queries against the 'default' database."
  133. # Python 3.11 uses fully qualified test name in the output.
  134. % (".test_foo" if PY311 else ""),
  135. )
  136. class SkippingClassTestCase(TestCase):
  137. def test_skip_class_unless_db_feature(self):
  138. @skipUnlessDBFeature("__class__")
  139. class NotSkippedTests(TestCase):
  140. def test_dummy(self):
  141. return
  142. @skipUnlessDBFeature("missing")
  143. @skipIfDBFeature("__class__")
  144. class SkippedTests(TestCase):
  145. def test_will_be_skipped(self):
  146. self.fail("We should never arrive here.")
  147. @skipIfDBFeature("__dict__")
  148. class SkippedTestsSubclass(SkippedTests):
  149. pass
  150. test_suite = unittest.TestSuite()
  151. test_suite.addTest(NotSkippedTests("test_dummy"))
  152. try:
  153. test_suite.addTest(SkippedTests("test_will_be_skipped"))
  154. test_suite.addTest(SkippedTestsSubclass("test_will_be_skipped"))
  155. except unittest.SkipTest:
  156. self.fail("SkipTest should not be raised here.")
  157. result = unittest.TextTestRunner(stream=StringIO()).run(test_suite)
  158. # PY312: Python 3.12.1+ no longer includes skipped tests in the number
  159. # of running tests.
  160. self.assertEqual(result.testsRun, 1 if sys.version_info >= (3, 12, 1) else 3)
  161. self.assertEqual(len(result.skipped), 2)
  162. self.assertEqual(result.skipped[0][1], "Database has feature(s) __class__")
  163. self.assertEqual(result.skipped[1][1], "Database has feature(s) __class__")
  164. def test_missing_default_databases(self):
  165. @skipIfDBFeature("missing")
  166. class MissingDatabases(SimpleTestCase):
  167. def test_assertion_error(self):
  168. pass
  169. suite = unittest.TestSuite()
  170. try:
  171. suite.addTest(MissingDatabases("test_assertion_error"))
  172. except unittest.SkipTest:
  173. self.fail("SkipTest should not be raised at this stage")
  174. runner = unittest.TextTestRunner(stream=StringIO())
  175. msg = (
  176. "skipIfDBFeature cannot be used on <class 'test_utils.tests."
  177. "SkippingClassTestCase.test_missing_default_databases.<locals>."
  178. "MissingDatabases'> as it doesn't allow queries against the "
  179. "'default' database."
  180. )
  181. with self.assertRaisesMessage(ValueError, msg):
  182. runner.run(suite)
  183. @override_settings(ROOT_URLCONF="test_utils.urls")
  184. class AssertNumQueriesTests(TestCase):
  185. def test_assert_num_queries(self):
  186. def test_func():
  187. raise ValueError
  188. with self.assertRaises(ValueError):
  189. self.assertNumQueries(2, test_func)
  190. def test_assert_num_queries_with_client(self):
  191. person = Person.objects.create(name="test")
  192. self.assertNumQueries(
  193. 1, self.client.get, "/test_utils/get_person/%s/" % person.pk
  194. )
  195. self.assertNumQueries(
  196. 1, self.client.get, "/test_utils/get_person/%s/" % person.pk
  197. )
  198. def test_func():
  199. self.client.get("/test_utils/get_person/%s/" % person.pk)
  200. self.client.get("/test_utils/get_person/%s/" % person.pk)
  201. self.assertNumQueries(2, test_func)
class AssertNumQueriesUponConnectionTests(TransactionTestCase):
    """
    Check that assertNumQueries() does not count a query executed while the
    connection is being (re)opened.
    """

    available_apps = []

    def test_ignores_connection_configuration_queries(self):
        # Keep the unpatched method so the mock below can delegate to it.
        real_ensure_connection = connection.ensure_connection
        # Close the connection so the next query has to open a new one.
        connection.close()

        def make_configuration_query():
            # Stand-in for ensure_connection() that additionally runs one
            # query right after a fresh connection is opened.
            is_opening_connection = connection.connection is None
            real_ensure_connection()

            if is_opening_connection:
                # Avoid infinite recursion. Creating a cursor calls
                # ensure_connection() which is currently mocked by this method.
                with connection.cursor() as cursor:
                    cursor.execute("SELECT 1" + connection.features.bare_select_suffix)

        ensure_connection = (
            "django.db.backends.base.base.BaseDatabaseWrapper.ensure_connection"
        )
        with mock.patch(ensure_connection, side_effect=make_configuration_query):
            # Exactly one query is expected: the SELECT 1 issued while
            # connecting must not be counted.
            with self.assertNumQueries(1):
                list(Car.objects.all())
  221. class AssertQuerySetEqualTests(TestCase):
  222. @classmethod
  223. def setUpTestData(cls):
  224. cls.p1 = Person.objects.create(name="p1")
  225. cls.p2 = Person.objects.create(name="p2")
  226. def test_empty(self):
  227. self.assertQuerySetEqual(Person.objects.filter(name="p3"), [])
  228. def test_ordered(self):
  229. self.assertQuerySetEqual(
  230. Person.objects.order_by("name"),
  231. [self.p1, self.p2],
  232. )
  233. def test_unordered(self):
  234. self.assertQuerySetEqual(
  235. Person.objects.order_by("name"), [self.p2, self.p1], ordered=False
  236. )
  237. def test_queryset(self):
  238. self.assertQuerySetEqual(
  239. Person.objects.order_by("name"),
  240. Person.objects.order_by("name"),
  241. )
  242. def test_flat_values_list(self):
  243. self.assertQuerySetEqual(
  244. Person.objects.order_by("name").values_list("name", flat=True),
  245. ["p1", "p2"],
  246. )
  247. def test_transform(self):
  248. self.assertQuerySetEqual(
  249. Person.objects.order_by("name"),
  250. [self.p1.pk, self.p2.pk],
  251. transform=lambda x: x.pk,
  252. )
  253. def test_repr_transform(self):
  254. self.assertQuerySetEqual(
  255. Person.objects.order_by("name"),
  256. [repr(self.p1), repr(self.p2)],
  257. transform=repr,
  258. )
  259. def test_undefined_order(self):
  260. # Using an unordered queryset with more than one ordered value
  261. # is an error.
  262. msg = (
  263. "Trying to compare non-ordered queryset against more than one "
  264. "ordered value."
  265. )
  266. with self.assertRaisesMessage(ValueError, msg):
  267. self.assertQuerySetEqual(
  268. Person.objects.all(),
  269. [self.p1, self.p2],
  270. )
  271. # No error for one value.
  272. self.assertQuerySetEqual(Person.objects.filter(name="p1"), [self.p1])
  273. def test_repeated_values(self):
  274. """
  275. assertQuerySetEqual checks the number of appearance of each item
  276. when used with option ordered=False.
  277. """
  278. batmobile = Car.objects.create(name="Batmobile")
  279. k2000 = Car.objects.create(name="K 2000")
  280. PossessedCar.objects.bulk_create(
  281. [
  282. PossessedCar(car=batmobile, belongs_to=self.p1),
  283. PossessedCar(car=batmobile, belongs_to=self.p1),
  284. PossessedCar(car=k2000, belongs_to=self.p1),
  285. PossessedCar(car=k2000, belongs_to=self.p1),
  286. PossessedCar(car=k2000, belongs_to=self.p1),
  287. PossessedCar(car=k2000, belongs_to=self.p1),
  288. ]
  289. )
  290. with self.assertRaises(AssertionError):
  291. self.assertQuerySetEqual(
  292. self.p1.cars.all(), [batmobile, k2000], ordered=False
  293. )
  294. self.assertQuerySetEqual(
  295. self.p1.cars.all(), [batmobile] * 2 + [k2000] * 4, ordered=False
  296. )
  297. def test_maxdiff(self):
  298. names = ["Joe Smith %s" % i for i in range(20)]
  299. Person.objects.bulk_create([Person(name=name) for name in names])
  300. names.append("Extra Person")
  301. with self.assertRaises(AssertionError) as ctx:
  302. self.assertQuerySetEqual(
  303. Person.objects.filter(name__startswith="Joe"),
  304. names,
  305. ordered=False,
  306. transform=lambda p: p.name,
  307. )
  308. self.assertIn("Set self.maxDiff to None to see it.", str(ctx.exception))
  309. original = self.maxDiff
  310. self.maxDiff = None
  311. try:
  312. with self.assertRaises(AssertionError) as ctx:
  313. self.assertQuerySetEqual(
  314. Person.objects.filter(name__startswith="Joe"),
  315. names,
  316. ordered=False,
  317. transform=lambda p: p.name,
  318. )
  319. finally:
  320. self.maxDiff = original
  321. exception_msg = str(ctx.exception)
  322. self.assertNotIn("Set self.maxDiff to None to see it.", exception_msg)
  323. for name in names:
  324. self.assertIn(name, exception_msg)
  325. @override_settings(ROOT_URLCONF="test_utils.urls")
  326. class CaptureQueriesContextManagerTests(TestCase):
  327. @classmethod
  328. def setUpTestData(cls):
  329. cls.person_pk = str(Person.objects.create(name="test").pk)
  330. def test_simple(self):
  331. with CaptureQueriesContext(connection) as captured_queries:
  332. Person.objects.get(pk=self.person_pk)
  333. self.assertEqual(len(captured_queries), 1)
  334. self.assertIn(self.person_pk, captured_queries[0]["sql"])
  335. with CaptureQueriesContext(connection) as captured_queries:
  336. pass
  337. self.assertEqual(0, len(captured_queries))
  338. def test_within(self):
  339. with CaptureQueriesContext(connection) as captured_queries:
  340. Person.objects.get(pk=self.person_pk)
  341. self.assertEqual(len(captured_queries), 1)
  342. self.assertIn(self.person_pk, captured_queries[0]["sql"])
  343. def test_nested(self):
  344. with CaptureQueriesContext(connection) as captured_queries:
  345. Person.objects.count()
  346. with CaptureQueriesContext(connection) as nested_captured_queries:
  347. Person.objects.count()
  348. self.assertEqual(1, len(nested_captured_queries))
  349. self.assertEqual(2, len(captured_queries))
  350. def test_failure(self):
  351. with self.assertRaises(TypeError):
  352. with CaptureQueriesContext(connection):
  353. raise TypeError
  354. def test_with_client(self):
  355. with CaptureQueriesContext(connection) as captured_queries:
  356. self.client.get("/test_utils/get_person/%s/" % self.person_pk)
  357. self.assertEqual(len(captured_queries), 1)
  358. self.assertIn(self.person_pk, captured_queries[0]["sql"])
  359. with CaptureQueriesContext(connection) as captured_queries:
  360. self.client.get("/test_utils/get_person/%s/" % self.person_pk)
  361. self.assertEqual(len(captured_queries), 1)
  362. self.assertIn(self.person_pk, captured_queries[0]["sql"])
  363. with CaptureQueriesContext(connection) as captured_queries:
  364. self.client.get("/test_utils/get_person/%s/" % self.person_pk)
  365. self.client.get("/test_utils/get_person/%s/" % self.person_pk)
  366. self.assertEqual(len(captured_queries), 2)
  367. self.assertIn(self.person_pk, captured_queries[0]["sql"])
  368. self.assertIn(self.person_pk, captured_queries[1]["sql"])
  369. @override_settings(ROOT_URLCONF="test_utils.urls")
  370. class AssertNumQueriesContextManagerTests(TestCase):
  371. def test_simple(self):
  372. with self.assertNumQueries(0):
  373. pass
  374. with self.assertNumQueries(1):
  375. Person.objects.count()
  376. with self.assertNumQueries(2):
  377. Person.objects.count()
  378. Person.objects.count()
  379. def test_failure(self):
  380. msg = "1 != 2 : 1 queries executed, 2 expected\nCaptured queries were:\n1."
  381. with self.assertRaisesMessage(AssertionError, msg):
  382. with self.assertNumQueries(2):
  383. Person.objects.count()
  384. with self.assertRaises(TypeError):
  385. with self.assertNumQueries(4000):
  386. raise TypeError
  387. def test_with_client(self):
  388. person = Person.objects.create(name="test")
  389. with self.assertNumQueries(1):
  390. self.client.get("/test_utils/get_person/%s/" % person.pk)
  391. with self.assertNumQueries(1):
  392. self.client.get("/test_utils/get_person/%s/" % person.pk)
  393. with self.assertNumQueries(2):
  394. self.client.get("/test_utils/get_person/%s/" % person.pk)
  395. self.client.get("/test_utils/get_person/%s/" % person.pk)
  396. @override_settings(ROOT_URLCONF="test_utils.urls")
  397. class AssertTemplateUsedContextManagerTests(SimpleTestCase):
  398. def test_usage(self):
  399. with self.assertTemplateUsed("template_used/base.html"):
  400. render_to_string("template_used/base.html")
  401. with self.assertTemplateUsed(template_name="template_used/base.html"):
  402. render_to_string("template_used/base.html")
  403. with self.assertTemplateUsed("template_used/base.html"):
  404. render_to_string("template_used/include.html")
  405. with self.assertTemplateUsed("template_used/base.html"):
  406. render_to_string("template_used/extends.html")
  407. with self.assertTemplateUsed("template_used/base.html"):
  408. render_to_string("template_used/base.html")
  409. render_to_string("template_used/base.html")
  410. def test_nested_usage(self):
  411. with self.assertTemplateUsed("template_used/base.html"):
  412. with self.assertTemplateUsed("template_used/include.html"):
  413. render_to_string("template_used/include.html")
  414. with self.assertTemplateUsed("template_used/extends.html"):
  415. with self.assertTemplateUsed("template_used/base.html"):
  416. render_to_string("template_used/extends.html")
  417. with self.assertTemplateUsed("template_used/base.html"):
  418. with self.assertTemplateUsed("template_used/alternative.html"):
  419. render_to_string("template_used/alternative.html")
  420. render_to_string("template_used/base.html")
  421. with self.assertTemplateUsed("template_used/base.html"):
  422. render_to_string("template_used/extends.html")
  423. with self.assertTemplateNotUsed("template_used/base.html"):
  424. render_to_string("template_used/alternative.html")
  425. render_to_string("template_used/base.html")
  426. def test_not_used(self):
  427. with self.assertTemplateNotUsed("template_used/base.html"):
  428. pass
  429. with self.assertTemplateNotUsed("template_used/alternative.html"):
  430. pass
  431. def test_error_message(self):
  432. msg = "No templates used to render the response"
  433. with self.assertRaisesMessage(AssertionError, msg):
  434. with self.assertTemplateUsed("template_used/base.html"):
  435. pass
  436. with self.assertRaisesMessage(AssertionError, msg):
  437. with self.assertTemplateUsed(template_name="template_used/base.html"):
  438. pass
  439. msg2 = (
  440. "Template 'template_used/base.html' was not a template used to render "
  441. "the response. Actual template(s) used: template_used/alternative.html"
  442. )
  443. with self.assertRaisesMessage(AssertionError, msg2):
  444. with self.assertTemplateUsed("template_used/base.html"):
  445. render_to_string("template_used/alternative.html")
  446. msg = "No templates used to render the response"
  447. with self.assertRaisesMessage(AssertionError, msg):
  448. response = self.client.get("/test_utils/no_template_used/")
  449. self.assertTemplateUsed(response, "template_used/base.html")
  450. with self.assertRaisesMessage(AssertionError, msg):
  451. with self.assertTemplateUsed("template_used/base.html"):
  452. self.client.get("/test_utils/no_template_used/")
  453. with self.assertRaisesMessage(AssertionError, msg):
  454. with self.assertTemplateUsed("template_used/base.html"):
  455. template = Template("template_used/alternative.html", name=None)
  456. template.render(Context())
  457. def test_msg_prefix(self):
  458. msg_prefix = "Prefix"
  459. msg = f"{msg_prefix}: No templates used to render the response"
  460. with self.assertRaisesMessage(AssertionError, msg):
  461. with self.assertTemplateUsed(
  462. "template_used/base.html", msg_prefix=msg_prefix
  463. ):
  464. pass
  465. with self.assertRaisesMessage(AssertionError, msg):
  466. with self.assertTemplateUsed(
  467. template_name="template_used/base.html",
  468. msg_prefix=msg_prefix,
  469. ):
  470. pass
  471. msg = (
  472. f"{msg_prefix}: Template 'template_used/base.html' was not a "
  473. f"template used to render the response. Actual template(s) used: "
  474. f"template_used/alternative.html"
  475. )
  476. with self.assertRaisesMessage(AssertionError, msg):
  477. with self.assertTemplateUsed(
  478. "template_used/base.html", msg_prefix=msg_prefix
  479. ):
  480. render_to_string("template_used/alternative.html")
  481. def test_count(self):
  482. with self.assertTemplateUsed("template_used/base.html", count=2):
  483. render_to_string("template_used/base.html")
  484. render_to_string("template_used/base.html")
  485. msg = (
  486. "Template 'template_used/base.html' was expected to be rendered "
  487. "3 time(s) but was actually rendered 2 time(s)."
  488. )
  489. with self.assertRaisesMessage(AssertionError, msg):
  490. with self.assertTemplateUsed("template_used/base.html", count=3):
  491. render_to_string("template_used/base.html")
  492. render_to_string("template_used/base.html")
  493. def test_failure(self):
  494. msg = "response and/or template_name argument must be provided"
  495. with self.assertRaisesMessage(TypeError, msg):
  496. with self.assertTemplateUsed():
  497. pass
  498. msg = "No templates used to render the response"
  499. with self.assertRaisesMessage(AssertionError, msg):
  500. with self.assertTemplateUsed(""):
  501. pass
  502. with self.assertRaisesMessage(AssertionError, msg):
  503. with self.assertTemplateUsed(""):
  504. render_to_string("template_used/base.html")
  505. with self.assertRaisesMessage(AssertionError, msg):
  506. with self.assertTemplateUsed(template_name=""):
  507. pass
  508. msg = (
  509. "Template 'template_used/base.html' was not a template used to "
  510. "render the response. Actual template(s) used: "
  511. "template_used/alternative.html"
  512. )
  513. with self.assertRaisesMessage(AssertionError, msg):
  514. with self.assertTemplateUsed("template_used/base.html"):
  515. render_to_string("template_used/alternative.html")
  516. def test_assert_used_on_http_response(self):
  517. response = HttpResponse()
  518. msg = "%s() is only usable on responses fetched using the Django test Client."
  519. with self.assertRaisesMessage(ValueError, msg % "assertTemplateUsed"):
  520. self.assertTemplateUsed(response, "template.html")
  521. with self.assertRaisesMessage(ValueError, msg % "assertTemplateNotUsed"):
  522. self.assertTemplateNotUsed(response, "template.html")
  523. class HTMLEqualTests(SimpleTestCase):
  524. def test_html_parser(self):
  525. element = parse_html("<div><p>Hello</p></div>")
  526. self.assertEqual(len(element.children), 1)
  527. self.assertEqual(element.children[0].name, "p")
  528. self.assertEqual(element.children[0].children[0], "Hello")
  529. parse_html("<p>")
  530. parse_html("<p attr>")
  531. dom = parse_html("<p>foo")
  532. self.assertEqual(len(dom.children), 1)
  533. self.assertEqual(dom.name, "p")
  534. self.assertEqual(dom[0], "foo")
    def test_parse_html_in_script(self):
        # Markup-looking text inside <script> must parse without error.
        parse_html('<script>var a = "<p" + ">";</script>')
        parse_html(
            """
<script>
var js_sha_link='<p>***</p>';
</script>
"""
        )

        # script content will be parsed to text
        dom = parse_html(
            """
<script><p>foo</p> '</scr'+'ipt>' <span>bar</span></script>
"""
        )
        # The whole script body is expected as a single text child.
        self.assertEqual(len(dom.children), 1)
        self.assertEqual(dom.children[0], "<p>foo</p> '</scr'+'ipt>' <span>bar</span>")
  552. def test_void_elements(self):
  553. for tag in VOID_ELEMENTS:
  554. with self.subTest(tag):
  555. dom = parse_html("<p>Hello <%s> world</p>" % tag)
  556. self.assertEqual(len(dom.children), 3)
  557. self.assertEqual(dom[0], "Hello")
  558. self.assertEqual(dom[1].name, tag)
  559. self.assertEqual(dom[2], "world")
  560. dom = parse_html("<p>Hello <%s /> world</p>" % tag)
  561. self.assertEqual(len(dom.children), 3)
  562. self.assertEqual(dom[0], "Hello")
  563. self.assertEqual(dom[1].name, tag)
  564. self.assertEqual(dom[2], "world")
  565. def test_simple_equal_html(self):
  566. self.assertHTMLEqual("", "")
  567. self.assertHTMLEqual("<p></p>", "<p></p>")
  568. self.assertHTMLEqual("<p></p>", " <p> </p> ")
  569. self.assertHTMLEqual("<div><p>Hello</p></div>", "<div><p>Hello</p></div>")
  570. self.assertHTMLEqual("<div><p>Hello</p></div>", "<div> <p>Hello</p> </div>")
  571. self.assertHTMLEqual("<div>\n<p>Hello</p></div>", "<div><p>Hello</p></div>\n")
  572. self.assertHTMLEqual(
  573. "<div><p>Hello\nWorld !</p></div>", "<div><p>Hello World\n!</p></div>"
  574. )
  575. self.assertHTMLEqual(
  576. "<div><p>Hello\nWorld !</p></div>", "<div><p>Hello World\n!</p></div>"
  577. )
  578. self.assertHTMLEqual("<p>Hello World !</p>", "<p>Hello World\n\n!</p>")
  579. self.assertHTMLEqual("<p> </p>", "<p></p>")
  580. self.assertHTMLEqual("<p/>", "<p></p>")
  581. self.assertHTMLEqual("<p />", "<p></p>")
  582. self.assertHTMLEqual("<input checked>", '<input checked="checked">')
  583. self.assertHTMLEqual("<p>Hello", "<p> Hello")
  584. self.assertHTMLEqual("<p>Hello</p>World", "<p>Hello</p> World")
  585. def test_ignore_comments(self):
  586. self.assertHTMLEqual(
  587. "<div>Hello<!-- this is a comment --> World!</div>",
  588. "<div>Hello World!</div>",
  589. )
  590. def test_unequal_html(self):
  591. self.assertHTMLNotEqual("<p>Hello</p>", "<p>Hello!</p>")
  592. self.assertHTMLNotEqual("<p>foo&#20;bar</p>", "<p>foo&nbsp;bar</p>")
  593. self.assertHTMLNotEqual("<p>foo bar</p>", "<p>foo &nbsp;bar</p>")
  594. self.assertHTMLNotEqual("<p>foo nbsp</p>", "<p>foo &nbsp;</p>")
  595. self.assertHTMLNotEqual("<p>foo #20</p>", "<p>foo &#20;</p>")
  596. self.assertHTMLNotEqual(
  597. "<p><span>Hello</span><span>World</span></p>",
  598. "<p><span>Hello</span>World</p>",
  599. )
  600. self.assertHTMLNotEqual(
  601. "<p><span>Hello</span>World</p>",
  602. "<p><span>Hello</span><span>World</span></p>",
  603. )
  604. def test_attributes(self):
  605. self.assertHTMLEqual(
  606. '<input type="text" id="id_name" />', '<input id="id_name" type="text" />'
  607. )
  608. self.assertHTMLEqual(
  609. """<input type='text' id="id_name" />""",
  610. '<input id="id_name" type="text" />',
  611. )
  612. self.assertHTMLNotEqual(
  613. '<input type="text" id="id_name" />',
  614. '<input type="password" id="id_name" />',
  615. )
  616. def test_class_attribute(self):
  617. pairs = [
  618. ('<p class="foo bar"></p>', '<p class="bar foo"></p>'),
  619. ('<p class=" foo bar "></p>', '<p class="bar foo"></p>'),
  620. ('<p class=" foo bar "></p>', '<p class="bar foo"></p>'),
  621. ('<p class="foo\tbar"></p>', '<p class="bar foo"></p>'),
  622. ('<p class="\tfoo\tbar\t"></p>', '<p class="bar foo"></p>'),
  623. ('<p class="\t\t\tfoo\t\t\tbar\t\t\t"></p>', '<p class="bar foo"></p>'),
  624. ('<p class="\t \nfoo \t\nbar\n\t "></p>', '<p class="bar foo"></p>'),
  625. ]
  626. for html1, html2 in pairs:
  627. with self.subTest(html1):
  628. self.assertHTMLEqual(html1, html2)
  629. def test_boolean_attribute(self):
  630. html1 = "<input checked>"
  631. html2 = '<input checked="">'
  632. html3 = '<input checked="checked">'
  633. self.assertHTMLEqual(html1, html2)
  634. self.assertHTMLEqual(html1, html3)
  635. self.assertHTMLEqual(html2, html3)
  636. self.assertHTMLNotEqual(html1, '<input checked="invalid">')
  637. self.assertEqual(str(parse_html(html1)), "<input checked>")
  638. self.assertEqual(str(parse_html(html2)), "<input checked>")
  639. self.assertEqual(str(parse_html(html3)), "<input checked>")
  640. def test_non_boolean_attibutes(self):
  641. html1 = "<input value>"
  642. html2 = '<input value="">'
  643. html3 = '<input value="value">'
  644. self.assertHTMLEqual(html1, html2)
  645. self.assertHTMLNotEqual(html1, html3)
  646. self.assertEqual(str(parse_html(html1)), '<input value="">')
  647. self.assertEqual(str(parse_html(html2)), '<input value="">')
  648. def test_normalize_refs(self):
  649. pairs = [
  650. ("&#39;", "&#x27;"),
  651. ("&#39;", "'"),
  652. ("&#x27;", "&#39;"),
  653. ("&#x27;", "'"),
  654. ("'", "&#39;"),
  655. ("'", "&#x27;"),
  656. ("&amp;", "&#38;"),
  657. ("&amp;", "&#x26;"),
  658. ("&amp;", "&"),
  659. ("&#38;", "&amp;"),
  660. ("&#38;", "&#x26;"),
  661. ("&#38;", "&"),
  662. ("&#x26;", "&amp;"),
  663. ("&#x26;", "&#38;"),
  664. ("&#x26;", "&"),
  665. ("&", "&amp;"),
  666. ("&", "&#38;"),
  667. ("&", "&#x26;"),
  668. ]
  669. for pair in pairs:
  670. with self.subTest(repr(pair)):
  671. self.assertHTMLEqual(*pair)
    def test_complex_examples(self):
        # Table rows with differently ordered attributes and different
        # whitespace between tags are expected to compare equal.
        self.assertHTMLEqual(
            """<tr><th><label for="id_first_name">First name:</label></th>
<td><input type="text" name="first_name" value="John" id="id_first_name" /></td></tr>
<tr><th><label for="id_last_name">Last name:</label></th>
<td><input type="text" id="id_last_name" name="last_name" value="Lennon" /></td></tr>
<tr><th><label for="id_birthday">Birthday:</label></th>
<td><input type="text" value="1940-10-9" name="birthday" id="id_birthday" /></td></tr>""",  # NOQA
            """
<tr><th>
<label for="id_first_name">First name:</label></th><td>
<input type="text" name="first_name" value="John" id="id_first_name" />
</td></tr>
<tr><th>
<label for="id_last_name">Last name:</label></th><td>
<input type="text" name="last_name" value="Lennon" id="id_last_name" />
</td></tr>
<tr><th>
<label for="id_birthday">Birthday:</label></th><td>
<input type="text" name="birthday" value="1940-10-9" id="id_birthday" />
</td></tr>
""",
        )

        # A full document with a doctype is expected to equal a variant with
        # no doctype, HTML comments, and an explicit (invalid) closing </p>.
        self.assertHTMLEqual(
            """<!DOCTYPE html>
<html>
<head>
<link rel="stylesheet">
<title>Document</title>
<meta attribute="value">
</head>
<body>
<p>
This is a valid paragraph
<div> this is a div AFTER the p</div>
</body>
</html>""",
            """
<html>
<head>
<link rel="stylesheet">
<title>Document</title>
<meta attribute="value">
</head>
<body>
<p> This is a valid paragraph
<!-- browsers would close the p tag here -->
<div> this is a div AFTER the p</div>
</p> <!-- this is invalid HTML parsing, but it should make no
difference in most cases -->
</body>
</html>""",
        )
    def test_html_contain(self):
        # Membership (`in`) between parsed HTML trees and strings.
        # equal html contains each other
        dom1 = parse_html("<p>foo")
        dom2 = parse_html("<p>foo</p>")
        self.assertIn(dom1, dom2)
        self.assertIn(dom2, dom1)

        # A nested element is found inside its wrapper, but not vice versa.
        dom2 = parse_html("<div><p>foo</p></div>")
        self.assertIn(dom1, dom2)
        self.assertNotIn(dom2, dom1)
        # A string argument is matched against text content, not markup.
        self.assertNotIn("<p>foo</p>", dom2)
        self.assertIn("foo", dom2)

        # when a root element is used ... (i.e. several top-level elements)
        dom1 = parse_html("<p>foo</p><p>bar</p>")
        dom2 = parse_html("<p>foo</p><p>bar</p>")
        self.assertIn(dom1, dom2)
        dom1 = parse_html("<p>foo</p>")
        self.assertIn(dom1, dom2)
        dom1 = parse_html("<p>bar</p>")
        self.assertIn(dom1, dom2)
        dom1 = parse_html("<div><p>foo</p><p>bar</p></div>")
        self.assertIn(dom2, dom1)
    def test_count(self):
        # Element.count() for both parsed-HTML and plain-string needles.
        # equal html contains each other one time
        dom1 = parse_html("<p>foo")
        dom2 = parse_html("<p>foo</p>")
        self.assertEqual(dom1.count(dom2), 1)
        self.assertEqual(dom2.count(dom1), 1)

        dom2 = parse_html("<p>foo</p><p>bar</p>")
        self.assertEqual(dom2.count(dom1), 1)

        # String needles count occurrences in text content only...
        dom2 = parse_html("<p>foo foo</p><p>foo</p>")
        self.assertEqual(dom2.count("foo"), 3)

        # ...so tag names, attribute names and attribute values never match.
        dom2 = parse_html('<p class="bar">foo</p>')
        self.assertEqual(dom2.count("bar"), 0)
        self.assertEqual(dom2.count("class"), 0)
        self.assertEqual(dom2.count("p"), 0)
        self.assertEqual(dom2.count("o"), 2)

        dom2 = parse_html("<p>foo</p><p>foo</p>")
        self.assertEqual(dom2.count(dom1), 2)

        # An element with an extra child no longer matches <p>foo</p>.
        dom2 = parse_html('<div><p>foo<input type=""></p><p>foo</p></div>')
        self.assertEqual(dom2.count(dom1), 1)

        dom2 = parse_html("<div><div><p>foo</p></div></div>")
        self.assertEqual(dom2.count(dom1), 1)

        dom2 = parse_html("<p>foo<p>foo</p></p>")
        self.assertEqual(dom2.count(dom1), 1)

        dom2 = parse_html("<p>foo<p>bar</p></p>")
        self.assertEqual(dom2.count(dom1), 0)

        # HTML with a root element contains the same HTML with no root element.
        dom1 = parse_html("<p>foo</p><p>bar</p>")
        dom2 = parse_html("<div><p>foo</p><p>bar</p></div>")
        self.assertEqual(dom2.count(dom1), 1)

        # Target of search is a sequence of child elements and appears more
        # than once.
        dom2 = parse_html("<div><p>foo</p><p>bar</p><p>foo</p><p>bar</p></div>")
        self.assertEqual(dom2.count(dom1), 2)

        # Searched HTML has additional children.
        dom1 = parse_html("<a/><b/>")
        dom2 = parse_html("<a/><b/><c/>")
        self.assertEqual(dom2.count(dom1), 1)

        # No match found in children.
        dom1 = parse_html("<b/><a/>")
        self.assertEqual(dom2.count(dom1), 0)

        # Target of search found among children and grandchildren.
        dom1 = parse_html("<b/><b/>")
        dom2 = parse_html("<a><b/><b/></a><b/><b/>")
        self.assertEqual(dom2.count(dom1), 2)
  790. def test_root_element_escaped_html(self):
  791. html = "&lt;br&gt;"
  792. parsed = parse_html(html)
  793. self.assertEqual(str(parsed), html)
  794. def test_parsing_errors(self):
  795. with self.assertRaises(AssertionError):
  796. self.assertHTMLEqual("<p>", "")
  797. with self.assertRaises(AssertionError):
  798. self.assertHTMLEqual("", "<p>")
  799. error_msg = (
  800. "First argument is not valid HTML:\n"
  801. "('Unexpected end tag `div` (Line 1, Column 6)', (1, 6))"
  802. )
  803. with self.assertRaisesMessage(AssertionError, error_msg):
  804. self.assertHTMLEqual("< div></ div>", "<div></div>")
  805. with self.assertRaises(HTMLParseError):
  806. parse_html("</p>")
  807. def test_escaped_html_errors(self):
  808. msg = "<p>\n<foo>\n</p> != <p>\n&lt;foo&gt;\n</p>\n"
  809. with self.assertRaisesMessage(AssertionError, msg):
  810. self.assertHTMLEqual("<p><foo></p>", "<p>&lt;foo&gt;</p>")
  811. with self.assertRaisesMessage(AssertionError, msg):
  812. self.assertHTMLEqual("<p><foo></p>", "<p>&#60;foo&#62;</p>")
  813. def test_contains_html(self):
  814. response = HttpResponse(
  815. """<body>
  816. This is a form: <form method="get">
  817. <input type="text" name="Hello" />
  818. </form></body>"""
  819. )
  820. self.assertNotContains(response, "<input name='Hello' type='text'>")
  821. self.assertContains(response, '<form method="get">')
  822. self.assertContains(response, "<input name='Hello' type='text'>", html=True)
  823. self.assertNotContains(response, '<form method="get">', html=True)
  824. invalid_response = HttpResponse("""<body <bad>>""")
  825. with self.assertRaises(AssertionError):
  826. self.assertContains(invalid_response, "<p></p>")
  827. with self.assertRaises(AssertionError):
  828. self.assertContains(response, '<p "whats" that>')
  829. def test_unicode_handling(self):
  830. response = HttpResponse(
  831. '<p class="help">Some help text for the title (with Unicode ŠĐĆŽćžšđ)</p>'
  832. )
  833. self.assertContains(
  834. response,
  835. '<p class="help">Some help text for the title (with Unicode ŠĐĆŽćžšđ)</p>',
  836. html=True,
  837. )
  838. class InHTMLTests(SimpleTestCase):
  839. def test_needle_msg(self):
  840. msg = (
  841. "False is not true : Couldn't find '<b>Hello</b>' in the following "
  842. "response\n'<p>Test</p>'"
  843. )
  844. with self.assertRaisesMessage(AssertionError, msg):
  845. self.assertInHTML("<b>Hello</b>", "<p>Test</p>")
  846. def test_msg_prefix(self):
  847. msg = (
  848. "False is not true : Prefix: Couldn't find '<b>Hello</b>' in the following "
  849. 'response\n\'<input type="text" name="Hello" />\''
  850. )
  851. with self.assertRaisesMessage(AssertionError, msg):
  852. self.assertInHTML(
  853. "<b>Hello</b>",
  854. '<input type="text" name="Hello" />',
  855. msg_prefix="Prefix",
  856. )
  857. def test_count_msg_prefix(self):
  858. msg = (
  859. "2 != 1 : Prefix: Found 2 instances of '<b>Hello</b>' (expected 1) in the "
  860. "following response\n'<b>Hello</b><b>Hello</b>'"
  861. ""
  862. )
  863. with self.assertRaisesMessage(AssertionError, msg):
  864. self.assertInHTML(
  865. "<b>Hello</b>",
  866. "<b>Hello</b><b>Hello</b>",
  867. count=1,
  868. msg_prefix="Prefix",
  869. )
  870. def test_base(self):
  871. haystack = "<p><b>Hello</b> <span>there</span>! Hi <span>there</span>!</p>"
  872. self.assertInHTML("<b>Hello</b>", haystack=haystack)
  873. msg = f"Couldn't find '<p>Howdy</p>' in the following response\n{haystack!r}"
  874. with self.assertRaisesMessage(AssertionError, msg):
  875. self.assertInHTML("<p>Howdy</p>", haystack)
  876. self.assertInHTML("<span>there</span>", haystack=haystack, count=2)
  877. msg = (
  878. "Found 1 instances of '<b>Hello</b>' (expected 2) in the following response"
  879. f"\n{haystack!r}"
  880. )
  881. with self.assertRaisesMessage(AssertionError, msg):
  882. self.assertInHTML("<b>Hello</b>", haystack=haystack, count=2)
  883. def test_long_haystack(self):
  884. haystack = (
  885. "<p>This is a very very very very very very very very long message which "
  886. "exceedes the max limit of truncation.</p>"
  887. )
  888. msg = f"Couldn't find '<b>Hello</b>' in the following response\n{haystack!r}"
  889. with self.assertRaisesMessage(AssertionError, msg):
  890. self.assertInHTML("<b>Hello</b>", haystack)
  891. msg = (
  892. "Found 0 instances of '<b>This</b>' (expected 3) in the following response"
  893. f"\n{haystack!r}"
  894. )
  895. with self.assertRaisesMessage(AssertionError, msg):
  896. self.assertInHTML("<b>This</b>", haystack, 3)
  897. def test_assert_not_in_html(self):
  898. haystack = "<p><b>Hello</b> <span>there</span>! Hi <span>there</span>!</p>"
  899. self.assertNotInHTML("<b>Hi</b>", haystack=haystack)
  900. msg = (
  901. "'<b>Hello</b>' unexpectedly found in the following response"
  902. f"\n{haystack!r}"
  903. )
  904. with self.assertRaisesMessage(AssertionError, msg):
  905. self.assertNotInHTML("<b>Hello</b>", haystack=haystack)
  906. class JSONEqualTests(SimpleTestCase):
  907. def test_simple_equal(self):
  908. json1 = '{"attr1": "foo", "attr2":"baz"}'
  909. json2 = '{"attr1": "foo", "attr2":"baz"}'
  910. self.assertJSONEqual(json1, json2)
  911. def test_simple_equal_unordered(self):
  912. json1 = '{"attr1": "foo", "attr2":"baz"}'
  913. json2 = '{"attr2":"baz", "attr1": "foo"}'
  914. self.assertJSONEqual(json1, json2)
  915. def test_simple_equal_raise(self):
  916. json1 = '{"attr1": "foo", "attr2":"baz"}'
  917. json2 = '{"attr2":"baz"}'
  918. with self.assertRaises(AssertionError):
  919. self.assertJSONEqual(json1, json2)
  920. def test_equal_parsing_errors(self):
  921. invalid_json = '{"attr1": "foo, "attr2":"baz"}'
  922. valid_json = '{"attr1": "foo", "attr2":"baz"}'
  923. with self.assertRaises(AssertionError):
  924. self.assertJSONEqual(invalid_json, valid_json)
  925. with self.assertRaises(AssertionError):
  926. self.assertJSONEqual(valid_json, invalid_json)
  927. def test_simple_not_equal(self):
  928. json1 = '{"attr1": "foo", "attr2":"baz"}'
  929. json2 = '{"attr2":"baz"}'
  930. self.assertJSONNotEqual(json1, json2)
  931. def test_simple_not_equal_raise(self):
  932. json1 = '{"attr1": "foo", "attr2":"baz"}'
  933. json2 = '{"attr1": "foo", "attr2":"baz"}'
  934. with self.assertRaises(AssertionError):
  935. self.assertJSONNotEqual(json1, json2)
  936. def test_not_equal_parsing_errors(self):
  937. invalid_json = '{"attr1": "foo, "attr2":"baz"}'
  938. valid_json = '{"attr1": "foo", "attr2":"baz"}'
  939. with self.assertRaises(AssertionError):
  940. self.assertJSONNotEqual(invalid_json, valid_json)
  941. with self.assertRaises(AssertionError):
  942. self.assertJSONNotEqual(valid_json, invalid_json)
  943. class XMLEqualTests(SimpleTestCase):
  944. def test_simple_equal(self):
  945. xml1 = "<elem attr1='a' attr2='b' />"
  946. xml2 = "<elem attr1='a' attr2='b' />"
  947. self.assertXMLEqual(xml1, xml2)
  948. def test_simple_equal_unordered(self):
  949. xml1 = "<elem attr1='a' attr2='b' />"
  950. xml2 = "<elem attr2='b' attr1='a' />"
  951. self.assertXMLEqual(xml1, xml2)
  952. def test_simple_equal_raise(self):
  953. xml1 = "<elem attr1='a' />"
  954. xml2 = "<elem attr2='b' attr1='a' />"
  955. with self.assertRaises(AssertionError):
  956. self.assertXMLEqual(xml1, xml2)
  957. def test_simple_equal_raises_message(self):
  958. xml1 = "<elem attr1='a' />"
  959. xml2 = "<elem attr2='b' attr1='a' />"
  960. msg = """{xml1} != {xml2}
  961. - <elem attr1='a' />
  962. + <elem attr2='b' attr1='a' />
  963. ? ++++++++++
  964. """.format(
  965. xml1=repr(xml1), xml2=repr(xml2)
  966. )
  967. with self.assertRaisesMessage(AssertionError, msg):
  968. self.assertXMLEqual(xml1, xml2)
  969. def test_simple_not_equal(self):
  970. xml1 = "<elem attr1='a' attr2='c' />"
  971. xml2 = "<elem attr1='a' attr2='b' />"
  972. self.assertXMLNotEqual(xml1, xml2)
  973. def test_simple_not_equal_raise(self):
  974. xml1 = "<elem attr1='a' attr2='b' />"
  975. xml2 = "<elem attr2='b' attr1='a' />"
  976. with self.assertRaises(AssertionError):
  977. self.assertXMLNotEqual(xml1, xml2)
  978. def test_parsing_errors(self):
  979. xml_unvalid = "<elem attr1='a attr2='b' />"
  980. xml2 = "<elem attr2='b' attr1='a' />"
  981. with self.assertRaises(AssertionError):
  982. self.assertXMLNotEqual(xml_unvalid, xml2)
  983. def test_comment_root(self):
  984. xml1 = "<?xml version='1.0'?><!-- comment1 --><elem attr1='a' attr2='b' />"
  985. xml2 = "<?xml version='1.0'?><!-- comment2 --><elem attr2='b' attr1='a' />"
  986. self.assertXMLEqual(xml1, xml2)
  987. def test_simple_equal_with_leading_or_trailing_whitespace(self):
  988. xml1 = "<elem>foo</elem> \t\n"
  989. xml2 = " \t\n<elem>foo</elem>"
  990. self.assertXMLEqual(xml1, xml2)
  991. def test_simple_not_equal_with_whitespace_in_the_middle(self):
  992. xml1 = "<elem>foo</elem><elem>bar</elem>"
  993. xml2 = "<elem>foo</elem> <elem>bar</elem>"
  994. self.assertXMLNotEqual(xml1, xml2)
  995. def test_doctype_root(self):
  996. xml1 = '<?xml version="1.0"?><!DOCTYPE root SYSTEM "example1.dtd"><root />'
  997. xml2 = '<?xml version="1.0"?><!DOCTYPE root SYSTEM "example2.dtd"><root />'
  998. self.assertXMLEqual(xml1, xml2)
  999. def test_processing_instruction(self):
  1000. xml1 = (
  1001. '<?xml version="1.0"?>'
  1002. '<?xml-model href="http://www.example1.com"?><root />'
  1003. )
  1004. xml2 = (
  1005. '<?xml version="1.0"?>'
  1006. '<?xml-model href="http://www.example2.com"?><root />'
  1007. )
  1008. self.assertXMLEqual(xml1, xml2)
  1009. self.assertXMLEqual(
  1010. '<?xml-stylesheet href="style1.xslt" type="text/xsl"?><root />',
  1011. '<?xml-stylesheet href="style2.xslt" type="text/xsl"?><root />',
  1012. )
class SkippingExtraTests(TestCase):
    """
    A skipped test in a TestCase that declares fixtures must not load them.
    """

    # This fixture file is not expected to be loaded by this class: its only
    # test is skipped.
    fixtures = ["should_not_be_loaded.json"]

    # HACK: This depends on internals of our TestCase subclasses
    def __call__(self, result=None):
        # Wrap the whole per-test run so that any query issued while running
        # the (skipped) test is caught. Assumes fixture loading happens
        # within __call__ -- the internal detail the HACK note refers to.
        # Detect fixture loading by counting SQL queries, should be zero
        with self.assertNumQueries(0):
            super().__call__(result)

    @unittest.skip("Fixture loading should not be performed for skipped tests.")
    def test_fixtures_are_skipped(self):
        pass
  1023. class AssertRaisesMsgTest(SimpleTestCase):
  1024. def test_assert_raises_message(self):
  1025. msg = "'Expected message' not found in 'Unexpected message'"
  1026. # context manager form of assertRaisesMessage()
  1027. with self.assertRaisesMessage(AssertionError, msg):
  1028. with self.assertRaisesMessage(ValueError, "Expected message"):
  1029. raise ValueError("Unexpected message")
  1030. # callable form
  1031. def func():
  1032. raise ValueError("Unexpected message")
  1033. with self.assertRaisesMessage(AssertionError, msg):
  1034. self.assertRaisesMessage(ValueError, "Expected message", func)
  1035. def test_special_re_chars(self):
  1036. """assertRaisesMessage shouldn't interpret RE special chars."""
  1037. def func1():
  1038. raise ValueError("[.*x+]y?")
  1039. with self.assertRaisesMessage(ValueError, "[.*x+]y?"):
  1040. func1()
  1041. class AssertWarnsMessageTests(SimpleTestCase):
  1042. def test_context_manager(self):
  1043. with self.assertWarnsMessage(UserWarning, "Expected message"):
  1044. warnings.warn("Expected message", UserWarning)
  1045. def test_context_manager_failure(self):
  1046. msg = "Expected message' not found in 'Unexpected message'"
  1047. with self.assertRaisesMessage(AssertionError, msg):
  1048. with self.assertWarnsMessage(UserWarning, "Expected message"):
  1049. warnings.warn("Unexpected message", UserWarning)
  1050. def test_callable(self):
  1051. def func():
  1052. warnings.warn("Expected message", UserWarning)
  1053. self.assertWarnsMessage(UserWarning, "Expected message", func)
  1054. def test_special_re_chars(self):
  1055. def func1():
  1056. warnings.warn("[.*x+]y?", UserWarning)
  1057. with self.assertWarnsMessage(UserWarning, "[.*x+]y?"):
  1058. func1()
  1059. class AssertFieldOutputTests(SimpleTestCase):
  1060. def test_assert_field_output(self):
  1061. error_invalid = ["Enter a valid email address."]
  1062. self.assertFieldOutput(
  1063. EmailField, {"a@a.com": "a@a.com"}, {"aaa": error_invalid}
  1064. )
  1065. with self.assertRaises(AssertionError):
  1066. self.assertFieldOutput(
  1067. EmailField,
  1068. {"a@a.com": "a@a.com"},
  1069. {"aaa": error_invalid + ["Another error"]},
  1070. )
  1071. with self.assertRaises(AssertionError):
  1072. self.assertFieldOutput(
  1073. EmailField, {"a@a.com": "Wrong output"}, {"aaa": error_invalid}
  1074. )
  1075. with self.assertRaises(AssertionError):
  1076. self.assertFieldOutput(
  1077. EmailField,
  1078. {"a@a.com": "a@a.com"},
  1079. {"aaa": ["Come on, gimme some well formatted data, dude."]},
  1080. )
  1081. def test_custom_required_message(self):
  1082. class MyCustomField(IntegerField):
  1083. default_error_messages = {
  1084. "required": "This is really required.",
  1085. }
  1086. self.assertFieldOutput(MyCustomField, {}, {}, empty_value=None)
  1087. @override_settings(ROOT_URLCONF="test_utils.urls")
  1088. class AssertURLEqualTests(SimpleTestCase):
  1089. def test_equal(self):
  1090. valid_tests = (
  1091. ("http://example.com/?", "http://example.com/"),
  1092. ("http://example.com/?x=1&", "http://example.com/?x=1"),
  1093. ("http://example.com/?x=1&y=2", "http://example.com/?y=2&x=1"),
  1094. ("http://example.com/?x=1&y=2", "http://example.com/?y=2&x=1"),
  1095. (
  1096. "http://example.com/?x=1&y=2&a=1&a=2",
  1097. "http://example.com/?a=1&a=2&y=2&x=1",
  1098. ),
  1099. ("/path/to/?x=1&y=2&z=3", "/path/to/?z=3&y=2&x=1"),
  1100. ("?x=1&y=2&z=3", "?z=3&y=2&x=1"),
  1101. ("/test_utils/no_template_used/", reverse_lazy("no_template_used")),
  1102. )
  1103. for url1, url2 in valid_tests:
  1104. with self.subTest(url=url1):
  1105. self.assertURLEqual(url1, url2)
  1106. def test_not_equal(self):
  1107. invalid_tests = (
  1108. # Protocol must be the same.
  1109. ("http://example.com/", "https://example.com/"),
  1110. ("http://example.com/?x=1&x=2", "https://example.com/?x=2&x=1"),
  1111. ("http://example.com/?x=1&y=bar&x=2", "https://example.com/?y=bar&x=2&x=1"),
  1112. # Parameters of the same name must be in the same order.
  1113. ("/path/to?a=1&a=2", "/path/to/?a=2&a=1"),
  1114. )
  1115. for url1, url2 in invalid_tests:
  1116. with self.subTest(url=url1), self.assertRaises(AssertionError):
  1117. self.assertURLEqual(url1, url2)
  1118. def test_message(self):
  1119. msg = (
  1120. "Expected 'http://example.com/?x=1&x=2' to equal "
  1121. "'https://example.com/?x=2&x=1'"
  1122. )
  1123. with self.assertRaisesMessage(AssertionError, msg):
  1124. self.assertURLEqual(
  1125. "http://example.com/?x=1&x=2", "https://example.com/?x=2&x=1"
  1126. )
  1127. def test_msg_prefix(self):
  1128. msg = (
  1129. "Prefix: Expected 'http://example.com/?x=1&x=2' to equal "
  1130. "'https://example.com/?x=2&x=1'"
  1131. )
  1132. with self.assertRaisesMessage(AssertionError, msg):
  1133. self.assertURLEqual(
  1134. "http://example.com/?x=1&x=2",
  1135. "https://example.com/?x=2&x=1",
  1136. msg_prefix="Prefix",
  1137. )
  1138. class TestForm(Form):
  1139. field = CharField()
  1140. def clean_field(self):
  1141. value = self.cleaned_data.get("field", "")
  1142. if value == "invalid":
  1143. raise ValidationError("invalid value")
  1144. return value
  1145. def clean(self):
  1146. if self.cleaned_data.get("field") == "invalid_non_field":
  1147. raise ValidationError("non-field error")
  1148. return self.cleaned_data
  1149. @classmethod
  1150. def _get_cleaned_form(cls, field_value):
  1151. form = cls({"field": field_value})
  1152. form.full_clean()
  1153. return form
  1154. @classmethod
  1155. def valid(cls):
  1156. return cls._get_cleaned_form("valid")
  1157. @classmethod
  1158. def invalid(cls, nonfield=False):
  1159. return cls._get_cleaned_form("invalid_non_field" if nonfield else "invalid")
  1160. class TestFormset(formset_factory(TestForm)):
  1161. @classmethod
  1162. def _get_cleaned_formset(cls, field_value):
  1163. formset = cls(
  1164. {
  1165. "form-TOTAL_FORMS": "1",
  1166. "form-INITIAL_FORMS": "0",
  1167. "form-0-field": field_value,
  1168. }
  1169. )
  1170. formset.full_clean()
  1171. return formset
  1172. @classmethod
  1173. def valid(cls):
  1174. return cls._get_cleaned_formset("valid")
  1175. @classmethod
  1176. def invalid(cls, nonfield=False, nonform=False):
  1177. if nonform:
  1178. formset = cls({}, error_messages={"missing_management_form": "error"})
  1179. formset.full_clean()
  1180. return formset
  1181. return cls._get_cleaned_formset("invalid_non_field" if nonfield else "invalid")
  1182. class AssertFormErrorTests(SimpleTestCase):
  1183. def test_single_error(self):
  1184. self.assertFormError(TestForm.invalid(), "field", "invalid value")
  1185. def test_error_list(self):
  1186. self.assertFormError(TestForm.invalid(), "field", ["invalid value"])
  1187. def test_empty_errors_valid_form(self):
  1188. self.assertFormError(TestForm.valid(), "field", [])
  1189. def test_empty_errors_valid_form_non_field_errors(self):
  1190. self.assertFormError(TestForm.valid(), None, [])
  1191. def test_field_not_in_form(self):
  1192. msg = (
  1193. "The form <TestForm bound=True, valid=False, fields=(field)> does not "
  1194. "contain the field 'other_field'."
  1195. )
  1196. with self.assertRaisesMessage(AssertionError, msg):
  1197. self.assertFormError(TestForm.invalid(), "other_field", "invalid value")
  1198. msg_prefix = "Custom prefix"
  1199. with self.assertRaisesMessage(AssertionError, f"{msg_prefix}: {msg}"):
  1200. self.assertFormError(
  1201. TestForm.invalid(),
  1202. "other_field",
  1203. "invalid value",
  1204. msg_prefix=msg_prefix,
  1205. )
  1206. def test_field_with_no_errors(self):
  1207. msg = (
  1208. "The errors of field 'field' on form <TestForm bound=True, valid=True, "
  1209. "fields=(field)> don't match."
  1210. )
  1211. with self.assertRaisesMessage(AssertionError, msg) as ctx:
  1212. self.assertFormError(TestForm.valid(), "field", "invalid value")
  1213. self.assertIn("[] != ['invalid value']", str(ctx.exception))
  1214. msg_prefix = "Custom prefix"
  1215. with self.assertRaisesMessage(AssertionError, f"{msg_prefix}: {msg}"):
  1216. self.assertFormError(
  1217. TestForm.valid(), "field", "invalid value", msg_prefix=msg_prefix
  1218. )
  1219. def test_field_with_different_error(self):
  1220. msg = (
  1221. "The errors of field 'field' on form <TestForm bound=True, valid=False, "
  1222. "fields=(field)> don't match."
  1223. )
  1224. with self.assertRaisesMessage(AssertionError, msg) as ctx:
  1225. self.assertFormError(TestForm.invalid(), "field", "other error")
  1226. self.assertIn("['invalid value'] != ['other error']", str(ctx.exception))
  1227. msg_prefix = "Custom prefix"
  1228. with self.assertRaisesMessage(AssertionError, f"{msg_prefix}: {msg}"):
  1229. self.assertFormError(
  1230. TestForm.invalid(), "field", "other error", msg_prefix=msg_prefix
  1231. )
  1232. def test_unbound_form(self):
  1233. msg = (
  1234. "The form <TestForm bound=False, valid=Unknown, fields=(field)> is not "
  1235. "bound, it will never have any errors."
  1236. )
  1237. with self.assertRaisesMessage(AssertionError, msg):
  1238. self.assertFormError(TestForm(), "field", [])
  1239. msg_prefix = "Custom prefix"
  1240. with self.assertRaisesMessage(AssertionError, f"{msg_prefix}: {msg}"):
  1241. self.assertFormError(TestForm(), "field", [], msg_prefix=msg_prefix)
  1242. def test_empty_errors_invalid_form(self):
  1243. msg = (
  1244. "The errors of field 'field' on form <TestForm bound=True, valid=False, "
  1245. "fields=(field)> don't match."
  1246. )
  1247. with self.assertRaisesMessage(AssertionError, msg) as ctx:
  1248. self.assertFormError(TestForm.invalid(), "field", [])
  1249. self.assertIn("['invalid value'] != []", str(ctx.exception))
  1250. def test_non_field_errors(self):
  1251. self.assertFormError(TestForm.invalid(nonfield=True), None, "non-field error")
  1252. def test_different_non_field_errors(self):
  1253. msg = (
  1254. "The non-field errors of form <TestForm bound=True, valid=False, "
  1255. "fields=(field)> don't match."
  1256. )
  1257. with self.assertRaisesMessage(AssertionError, msg) as ctx:
  1258. self.assertFormError(
  1259. TestForm.invalid(nonfield=True), None, "other non-field error"
  1260. )
  1261. self.assertIn(
  1262. "['non-field error'] != ['other non-field error']", str(ctx.exception)
  1263. )
  1264. msg_prefix = "Custom prefix"
  1265. with self.assertRaisesMessage(AssertionError, f"{msg_prefix}: {msg}"):
  1266. self.assertFormError(
  1267. TestForm.invalid(nonfield=True),
  1268. None,
  1269. "other non-field error",
  1270. msg_prefix=msg_prefix,
  1271. )
  1272. class AssertFormSetErrorTests(SimpleTestCase):
  1273. def test_single_error(self):
  1274. self.assertFormSetError(TestFormset.invalid(), 0, "field", "invalid value")
  1275. def test_error_list(self):
  1276. self.assertFormSetError(TestFormset.invalid(), 0, "field", ["invalid value"])
  1277. def test_empty_errors_valid_formset(self):
  1278. self.assertFormSetError(TestFormset.valid(), 0, "field", [])
  1279. def test_multiple_forms(self):
  1280. formset = TestFormset(
  1281. {
  1282. "form-TOTAL_FORMS": "2",
  1283. "form-INITIAL_FORMS": "0",
  1284. "form-0-field": "valid",
  1285. "form-1-field": "invalid",
  1286. }
  1287. )
  1288. formset.full_clean()
  1289. self.assertFormSetError(formset, 0, "field", [])
  1290. self.assertFormSetError(formset, 1, "field", ["invalid value"])
  1291. def test_field_not_in_form(self):
  1292. msg = (
  1293. "The form 0 of formset <TestFormset: bound=True valid=False total_forms=1> "
  1294. "does not contain the field 'other_field'."
  1295. )
  1296. with self.assertRaisesMessage(AssertionError, msg):
  1297. self.assertFormSetError(
  1298. TestFormset.invalid(), 0, "other_field", "invalid value"
  1299. )
  1300. msg_prefix = "Custom prefix"
  1301. with self.assertRaisesMessage(AssertionError, f"{msg_prefix}: {msg}"):
  1302. self.assertFormSetError(
  1303. TestFormset.invalid(),
  1304. 0,
  1305. "other_field",
  1306. "invalid value",
  1307. msg_prefix=msg_prefix,
  1308. )
  1309. def test_field_with_no_errors(self):
  1310. msg = (
  1311. "The errors of field 'field' on form 0 of formset <TestFormset: bound=True "
  1312. "valid=True total_forms=1> don't match."
  1313. )
  1314. with self.assertRaisesMessage(AssertionError, msg) as ctx:
  1315. self.assertFormSetError(TestFormset.valid(), 0, "field", "invalid value")
  1316. self.assertIn("[] != ['invalid value']", str(ctx.exception))
  1317. msg_prefix = "Custom prefix"
  1318. with self.assertRaisesMessage(AssertionError, f"{msg_prefix}: {msg}"):
  1319. self.assertFormSetError(
  1320. TestFormset.valid(), 0, "field", "invalid value", msg_prefix=msg_prefix
  1321. )
  1322. def test_field_with_different_error(self):
  1323. msg = (
  1324. "The errors of field 'field' on form 0 of formset <TestFormset: bound=True "
  1325. "valid=False total_forms=1> don't match."
  1326. )
  1327. with self.assertRaisesMessage(AssertionError, msg) as ctx:
  1328. self.assertFormSetError(TestFormset.invalid(), 0, "field", "other error")
  1329. self.assertIn("['invalid value'] != ['other error']", str(ctx.exception))
  1330. msg_prefix = "Custom prefix"
  1331. with self.assertRaisesMessage(AssertionError, f"{msg_prefix}: {msg}"):
  1332. self.assertFormSetError(
  1333. TestFormset.invalid(), 0, "field", "other error", msg_prefix=msg_prefix
  1334. )
  1335. def test_unbound_formset(self):
  1336. msg = (
  1337. "The formset <TestFormset: bound=False valid=Unknown total_forms=1> is not "
  1338. "bound, it will never have any errors."
  1339. )
  1340. with self.assertRaisesMessage(AssertionError, msg):
  1341. self.assertFormSetError(TestFormset(), 0, "field", [])
  1342. def test_empty_errors_invalid_formset(self):
  1343. msg = (
  1344. "The errors of field 'field' on form 0 of formset <TestFormset: bound=True "
  1345. "valid=False total_forms=1> don't match."
  1346. )
  1347. with self.assertRaisesMessage(AssertionError, msg) as ctx:
  1348. self.assertFormSetError(TestFormset.invalid(), 0, "field", [])
  1349. self.assertIn("['invalid value'] != []", str(ctx.exception))
  1350. def test_non_field_errors(self):
  1351. self.assertFormSetError(
  1352. TestFormset.invalid(nonfield=True), 0, None, "non-field error"
  1353. )
  1354. def test_different_non_field_errors(self):
  1355. msg = (
  1356. "The non-field errors of form 0 of formset <TestFormset: bound=True "
  1357. "valid=False total_forms=1> don't match."
  1358. )
  1359. with self.assertRaisesMessage(AssertionError, msg) as ctx:
  1360. self.assertFormSetError(
  1361. TestFormset.invalid(nonfield=True), 0, None, "other non-field error"
  1362. )
  1363. self.assertIn(
  1364. "['non-field error'] != ['other non-field error']", str(ctx.exception)
  1365. )
  1366. msg_prefix = "Custom prefix"
  1367. with self.assertRaisesMessage(AssertionError, f"{msg_prefix}: {msg}"):
  1368. self.assertFormSetError(
  1369. TestFormset.invalid(nonfield=True),
  1370. 0,
  1371. None,
  1372. "other non-field error",
  1373. msg_prefix=msg_prefix,
  1374. )
  1375. def test_no_non_field_errors(self):
  1376. msg = (
  1377. "The non-field errors of form 0 of formset <TestFormset: bound=True "
  1378. "valid=False total_forms=1> don't match."
  1379. )
  1380. with self.assertRaisesMessage(AssertionError, msg) as ctx:
  1381. self.assertFormSetError(TestFormset.invalid(), 0, None, "non-field error")
  1382. self.assertIn("[] != ['non-field error']", str(ctx.exception))
  1383. msg_prefix = "Custom prefix"
  1384. with self.assertRaisesMessage(AssertionError, f"{msg_prefix}: {msg}"):
  1385. self.assertFormSetError(
  1386. TestFormset.invalid(), 0, None, "non-field error", msg_prefix=msg_prefix
  1387. )
  1388. def test_non_form_errors(self):
  1389. self.assertFormSetError(TestFormset.invalid(nonform=True), None, None, "error")
  1390. def test_different_non_form_errors(self):
  1391. msg = (
  1392. "The non-form errors of formset <TestFormset: bound=True valid=False "
  1393. "total_forms=0> don't match."
  1394. )
  1395. with self.assertRaisesMessage(AssertionError, msg) as ctx:
  1396. self.assertFormSetError(
  1397. TestFormset.invalid(nonform=True), None, None, "other error"
  1398. )
  1399. self.assertIn("['error'] != ['other error']", str(ctx.exception))
  1400. msg_prefix = "Custom prefix"
  1401. with self.assertRaisesMessage(AssertionError, f"{msg_prefix}: {msg}"):
  1402. self.assertFormSetError(
  1403. TestFormset.invalid(nonform=True),
  1404. None,
  1405. None,
  1406. "other error",
  1407. msg_prefix=msg_prefix,
  1408. )
  1409. def test_no_non_form_errors(self):
  1410. msg = (
  1411. "The non-form errors of formset <TestFormset: bound=True valid=False "
  1412. "total_forms=1> don't match."
  1413. )
  1414. with self.assertRaisesMessage(AssertionError, msg) as ctx:
  1415. self.assertFormSetError(TestFormset.invalid(), None, None, "error")
  1416. self.assertIn("[] != ['error']", str(ctx.exception))
  1417. msg_prefix = "Custom prefix"
  1418. with self.assertRaisesMessage(AssertionError, f"{msg_prefix}: {msg}"):
  1419. self.assertFormSetError(
  1420. TestFormset.invalid(),
  1421. None,
  1422. None,
  1423. "error",
  1424. msg_prefix=msg_prefix,
  1425. )
  1426. def test_non_form_errors_with_field(self):
  1427. msg = "You must use field=None with form_index=None."
  1428. with self.assertRaisesMessage(ValueError, msg):
  1429. self.assertFormSetError(
  1430. TestFormset.invalid(nonform=True), None, "field", "error"
  1431. )
  1432. def test_form_index_too_big(self):
  1433. msg = (
  1434. "The formset <TestFormset: bound=True valid=False total_forms=1> only has "
  1435. "1 form."
  1436. )
  1437. with self.assertRaisesMessage(AssertionError, msg):
  1438. self.assertFormSetError(TestFormset.invalid(), 2, "field", "error")
  1439. def test_form_index_too_big_plural(self):
  1440. formset = TestFormset(
  1441. {
  1442. "form-TOTAL_FORMS": "2",
  1443. "form-INITIAL_FORMS": "0",
  1444. "form-0-field": "valid",
  1445. "form-1-field": "valid",
  1446. }
  1447. )
  1448. formset.full_clean()
  1449. msg = (
  1450. "The formset <TestFormset: bound=True valid=True total_forms=2> only has 2 "
  1451. "forms."
  1452. )
  1453. with self.assertRaisesMessage(AssertionError, msg):
  1454. self.assertFormSetError(formset, 2, "field", "error")
  1455. class FirstUrls:
  1456. urlpatterns = [path("first/", empty_response, name="first")]
  1457. class SecondUrls:
  1458. urlpatterns = [path("second/", empty_response, name="second")]
  1459. class SetupTestEnvironmentTests(SimpleTestCase):
  1460. def test_setup_test_environment_calling_more_than_once(self):
  1461. with self.assertRaisesMessage(
  1462. RuntimeError, "setup_test_environment() was already called"
  1463. ):
  1464. setup_test_environment()
  1465. def test_allowed_hosts(self):
  1466. for type_ in (list, tuple):
  1467. with self.subTest(type_=type_):
  1468. allowed_hosts = type_("*")
  1469. with mock.patch("django.test.utils._TestState") as x:
  1470. del x.saved_data
  1471. with self.settings(ALLOWED_HOSTS=allowed_hosts):
  1472. setup_test_environment()
  1473. self.assertEqual(settings.ALLOWED_HOSTS, ["*", "testserver"])
  1474. class OverrideSettingsTests(SimpleTestCase):
# #21518 -- If neither override_settings nor a setting_changed receiver
# clears the URL cache between tests, then one of test_urlconf_first or
# test_urlconf_second will fail.
  1478. @override_settings(ROOT_URLCONF=FirstUrls)
  1479. def test_urlconf_first(self):
  1480. reverse("first")
  1481. @override_settings(ROOT_URLCONF=SecondUrls)
  1482. def test_urlconf_second(self):
  1483. reverse("second")
  1484. def test_urlconf_cache(self):
  1485. with self.assertRaises(NoReverseMatch):
  1486. reverse("first")
  1487. with self.assertRaises(NoReverseMatch):
  1488. reverse("second")
  1489. with override_settings(ROOT_URLCONF=FirstUrls):
  1490. self.client.get(reverse("first"))
  1491. with self.assertRaises(NoReverseMatch):
  1492. reverse("second")
  1493. with override_settings(ROOT_URLCONF=SecondUrls):
  1494. with self.assertRaises(NoReverseMatch):
  1495. reverse("first")
  1496. self.client.get(reverse("second"))
  1497. self.client.get(reverse("first"))
  1498. with self.assertRaises(NoReverseMatch):
  1499. reverse("second")
  1500. with self.assertRaises(NoReverseMatch):
  1501. reverse("first")
  1502. with self.assertRaises(NoReverseMatch):
  1503. reverse("second")
  1504. def test_override_media_root(self):
  1505. """
  1506. Overriding the MEDIA_ROOT setting should be reflected in the
  1507. base_location attribute of django.core.files.storage.default_storage.
  1508. """
  1509. self.assertEqual(default_storage.base_location, "")
  1510. with self.settings(MEDIA_ROOT="test_value"):
  1511. self.assertEqual(default_storage.base_location, "test_value")
  1512. def test_override_media_url(self):
  1513. """
  1514. Overriding the MEDIA_URL setting should be reflected in the
  1515. base_url attribute of django.core.files.storage.default_storage.
  1516. """
  1517. self.assertEqual(default_storage.base_location, "")
  1518. with self.settings(MEDIA_URL="/test_value/"):
  1519. self.assertEqual(default_storage.base_url, "/test_value/")
  1520. def test_override_file_upload_permissions(self):
  1521. """
  1522. Overriding the FILE_UPLOAD_PERMISSIONS setting should be reflected in
  1523. the file_permissions_mode attribute of
  1524. django.core.files.storage.default_storage.
  1525. """
  1526. self.assertEqual(default_storage.file_permissions_mode, 0o644)
  1527. with self.settings(FILE_UPLOAD_PERMISSIONS=0o777):
  1528. self.assertEqual(default_storage.file_permissions_mode, 0o777)
  1529. def test_override_file_upload_directory_permissions(self):
  1530. """
  1531. Overriding the FILE_UPLOAD_DIRECTORY_PERMISSIONS setting should be
  1532. reflected in the directory_permissions_mode attribute of
  1533. django.core.files.storage.default_storage.
  1534. """
  1535. self.assertIsNone(default_storage.directory_permissions_mode)
  1536. with self.settings(FILE_UPLOAD_DIRECTORY_PERMISSIONS=0o777):
  1537. self.assertEqual(default_storage.directory_permissions_mode, 0o777)
  1538. def test_override_database_routers(self):
  1539. """
  1540. Overriding DATABASE_ROUTERS should update the base router.
  1541. """
  1542. test_routers = [object()]
  1543. with self.settings(DATABASE_ROUTERS=test_routers):
  1544. self.assertEqual(router.routers, test_routers)
  1545. def test_override_static_url(self):
  1546. """
  1547. Overriding the STATIC_URL setting should be reflected in the
  1548. base_url attribute of
  1549. django.contrib.staticfiles.storage.staticfiles_storage.
  1550. """
  1551. with self.settings(STATIC_URL="/test/"):
  1552. self.assertEqual(staticfiles_storage.base_url, "/test/")
  1553. def test_override_static_root(self):
  1554. """
  1555. Overriding the STATIC_ROOT setting should be reflected in the
  1556. location attribute of
  1557. django.contrib.staticfiles.storage.staticfiles_storage.
  1558. """
  1559. with self.settings(STATIC_ROOT="/tmp/test"):
  1560. self.assertEqual(staticfiles_storage.location, os.path.abspath("/tmp/test"))
  1561. def test_override_staticfiles_storage(self):
  1562. """
  1563. Overriding the STORAGES setting should be reflected in
  1564. the value of django.contrib.staticfiles.storage.staticfiles_storage.
  1565. """
  1566. new_class = "ManifestStaticFilesStorage"
  1567. new_storage = "django.contrib.staticfiles.storage." + new_class
  1568. with self.settings(
  1569. STORAGES={STATICFILES_STORAGE_ALIAS: {"BACKEND": new_storage}}
  1570. ):
  1571. self.assertEqual(staticfiles_storage.__class__.__name__, new_class)
  1572. def test_override_staticfiles_finders(self):
  1573. """
  1574. Overriding the STATICFILES_FINDERS setting should be reflected in
  1575. the return value of django.contrib.staticfiles.finders.get_finders.
  1576. """
  1577. current = get_finders()
  1578. self.assertGreater(len(list(current)), 1)
  1579. finders = ["django.contrib.staticfiles.finders.FileSystemFinder"]
  1580. with self.settings(STATICFILES_FINDERS=finders):
  1581. self.assertEqual(len(list(get_finders())), len(finders))
  1582. def test_override_staticfiles_dirs(self):
  1583. """
  1584. Overriding the STATICFILES_DIRS setting should be reflected in
  1585. the locations attribute of the
  1586. django.contrib.staticfiles.finders.FileSystemFinder instance.
  1587. """
  1588. finder = get_finder("django.contrib.staticfiles.finders.FileSystemFinder")
  1589. test_path = "/tmp/test"
  1590. expected_location = ("", test_path)
  1591. self.assertNotIn(expected_location, finder.locations)
  1592. with self.settings(STATICFILES_DIRS=[test_path]):
  1593. finder = get_finder("django.contrib.staticfiles.finders.FileSystemFinder")
  1594. self.assertIn(expected_location, finder.locations)
  1595. @skipUnlessDBFeature("supports_transactions")
  1596. class TestBadSetUpTestData(TestCase):
  1597. """
  1598. An exception in setUpTestData() shouldn't leak a transaction which would
  1599. cascade across the rest of the test suite.
  1600. """
  1601. class MyException(Exception):
  1602. pass
  1603. @classmethod
  1604. def setUpClass(cls):
  1605. try:
  1606. super().setUpClass()
  1607. except cls.MyException:
  1608. cls._in_atomic_block = connection.in_atomic_block
  1609. @classmethod
  1610. def tearDownClass(Cls):
  1611. # override to avoid a second cls._rollback_atomics() which would fail.
  1612. # Normal setUpClass() methods won't have exception handling so this
  1613. # method wouldn't typically be run.
  1614. pass
  1615. @classmethod
  1616. def setUpTestData(cls):
  1617. # Simulate a broken setUpTestData() method.
  1618. raise cls.MyException()
  1619. def test_failure_in_setUpTestData_should_rollback_transaction(self):
  1620. # setUpTestData() should call _rollback_atomics() so that the
  1621. # transaction doesn't leak.
  1622. self.assertFalse(self._in_atomic_block)
  1623. @skipUnlessDBFeature("supports_transactions")
  1624. class CaptureOnCommitCallbacksTests(TestCase):
  1625. databases = {"default", "other"}
  1626. callback_called = False
  1627. def enqueue_callback(self, using="default"):
  1628. def hook():
  1629. self.callback_called = True
  1630. transaction.on_commit(hook, using=using)
  1631. def test_no_arguments(self):
  1632. with self.captureOnCommitCallbacks() as callbacks:
  1633. self.enqueue_callback()
  1634. self.assertEqual(len(callbacks), 1)
  1635. self.assertIs(self.callback_called, False)
  1636. callbacks[0]()
  1637. self.assertIs(self.callback_called, True)
  1638. def test_using(self):
  1639. with self.captureOnCommitCallbacks(using="other") as callbacks:
  1640. self.enqueue_callback(using="other")
  1641. self.assertEqual(len(callbacks), 1)
  1642. self.assertIs(self.callback_called, False)
  1643. callbacks[0]()
  1644. self.assertIs(self.callback_called, True)
  1645. def test_different_using(self):
  1646. with self.captureOnCommitCallbacks(using="default") as callbacks:
  1647. self.enqueue_callback(using="other")
  1648. self.assertEqual(callbacks, [])
  1649. def test_execute(self):
  1650. with self.captureOnCommitCallbacks(execute=True) as callbacks:
  1651. self.enqueue_callback()
  1652. self.assertEqual(len(callbacks), 1)
  1653. self.assertIs(self.callback_called, True)
  1654. def test_pre_callback(self):
  1655. def pre_hook():
  1656. pass
  1657. transaction.on_commit(pre_hook, using="default")
  1658. with self.captureOnCommitCallbacks() as callbacks:
  1659. self.enqueue_callback()
  1660. self.assertEqual(len(callbacks), 1)
  1661. self.assertNotEqual(callbacks[0], pre_hook)
  1662. def test_with_rolled_back_savepoint(self):
  1663. with self.captureOnCommitCallbacks() as callbacks:
  1664. try:
  1665. with transaction.atomic():
  1666. self.enqueue_callback()
  1667. raise IntegrityError
  1668. except IntegrityError:
  1669. # Inner transaction.atomic() has been rolled back.
  1670. pass
  1671. self.assertEqual(callbacks, [])
  1672. def test_execute_recursive(self):
  1673. with self.captureOnCommitCallbacks(execute=True) as callbacks:
  1674. transaction.on_commit(self.enqueue_callback)
  1675. self.assertEqual(len(callbacks), 2)
  1676. self.assertIs(self.callback_called, True)
  1677. def test_execute_tree(self):
  1678. """
  1679. A visualisation of the callback tree tested. Each node is expected to
  1680. be visited only once:
  1681. └─branch_1
  1682. ├─branch_2
  1683. │ ├─leaf_1
  1684. │ └─leaf_2
  1685. └─leaf_3
  1686. """
  1687. branch_1_call_counter = 0
  1688. branch_2_call_counter = 0
  1689. leaf_1_call_counter = 0
  1690. leaf_2_call_counter = 0
  1691. leaf_3_call_counter = 0
  1692. def leaf_1():
  1693. nonlocal leaf_1_call_counter
  1694. leaf_1_call_counter += 1
  1695. def leaf_2():
  1696. nonlocal leaf_2_call_counter
  1697. leaf_2_call_counter += 1
  1698. def leaf_3():
  1699. nonlocal leaf_3_call_counter
  1700. leaf_3_call_counter += 1
  1701. def branch_1():
  1702. nonlocal branch_1_call_counter
  1703. branch_1_call_counter += 1
  1704. transaction.on_commit(branch_2)
  1705. transaction.on_commit(leaf_3)
  1706. def branch_2():
  1707. nonlocal branch_2_call_counter
  1708. branch_2_call_counter += 1
  1709. transaction.on_commit(leaf_1)
  1710. transaction.on_commit(leaf_2)
  1711. with self.captureOnCommitCallbacks(execute=True) as callbacks:
  1712. transaction.on_commit(branch_1)
  1713. self.assertEqual(branch_1_call_counter, 1)
  1714. self.assertEqual(branch_2_call_counter, 1)
  1715. self.assertEqual(leaf_1_call_counter, 1)
  1716. self.assertEqual(leaf_2_call_counter, 1)
  1717. self.assertEqual(leaf_3_call_counter, 1)
  1718. self.assertEqual(callbacks, [branch_1, branch_2, leaf_3, leaf_1, leaf_2])
  1719. def test_execute_robust(self):
  1720. class MyException(Exception):
  1721. pass
  1722. def hook():
  1723. self.callback_called = True
  1724. raise MyException("robust callback")
  1725. with self.assertLogs("django.test", "ERROR") as cm:
  1726. with self.captureOnCommitCallbacks(execute=True) as callbacks:
  1727. transaction.on_commit(hook, robust=True)
  1728. self.assertEqual(len(callbacks), 1)
  1729. self.assertIs(self.callback_called, True)
  1730. log_record = cm.records[0]
  1731. self.assertEqual(
  1732. log_record.getMessage(),
  1733. "Error calling CaptureOnCommitCallbacksTests.test_execute_robust.<locals>."
  1734. "hook in on_commit() (robust callback).",
  1735. )
  1736. self.assertIsNotNone(log_record.exc_info)
  1737. raised_exception = log_record.exc_info[1]
  1738. self.assertIsInstance(raised_exception, MyException)
  1739. self.assertEqual(str(raised_exception), "robust callback")
  1740. class DisallowedDatabaseQueriesTests(SimpleTestCase):
  1741. def test_disallowed_database_connections(self):
  1742. expected_message = (
  1743. "Database connections to 'default' are not allowed in SimpleTestCase "
  1744. "subclasses. Either subclass TestCase or TransactionTestCase to "
  1745. "ensure proper test isolation or add 'default' to "
  1746. "test_utils.tests.DisallowedDatabaseQueriesTests.databases to "
  1747. "silence this failure."
  1748. )
  1749. with self.assertRaisesMessage(DatabaseOperationForbidden, expected_message):
  1750. connection.connect()
  1751. with self.assertRaisesMessage(DatabaseOperationForbidden, expected_message):
  1752. connection.temporary_connection()
  1753. def test_disallowed_database_queries(self):
  1754. expected_message = (
  1755. "Database queries to 'default' are not allowed in SimpleTestCase "
  1756. "subclasses. Either subclass TestCase or TransactionTestCase to "
  1757. "ensure proper test isolation or add 'default' to "
  1758. "test_utils.tests.DisallowedDatabaseQueriesTests.databases to "
  1759. "silence this failure."
  1760. )
  1761. with self.assertRaisesMessage(DatabaseOperationForbidden, expected_message):
  1762. Car.objects.first()
  1763. def test_disallowed_database_chunked_cursor_queries(self):
  1764. expected_message = (
  1765. "Database queries to 'default' are not allowed in SimpleTestCase "
  1766. "subclasses. Either subclass TestCase or TransactionTestCase to "
  1767. "ensure proper test isolation or add 'default' to "
  1768. "test_utils.tests.DisallowedDatabaseQueriesTests.databases to "
  1769. "silence this failure."
  1770. )
  1771. with self.assertRaisesMessage(DatabaseOperationForbidden, expected_message):
  1772. next(Car.objects.iterator())
  1773. class AllowedDatabaseQueriesTests(SimpleTestCase):
  1774. databases = {"default"}
  1775. def test_allowed_database_queries(self):
  1776. Car.objects.first()
  1777. def test_allowed_database_chunked_cursor_queries(self):
  1778. next(Car.objects.iterator(), None)
  1779. class DatabaseAliasTests(SimpleTestCase):
  1780. def setUp(self):
  1781. self.addCleanup(setattr, self.__class__, "databases", self.databases)
  1782. def test_no_close_match(self):
  1783. self.__class__.databases = {"void"}
  1784. message = (
  1785. "test_utils.tests.DatabaseAliasTests.databases refers to 'void' which is "
  1786. "not defined in settings.DATABASES."
  1787. )
  1788. with self.assertRaisesMessage(ImproperlyConfigured, message):
  1789. self._validate_databases()
  1790. def test_close_match(self):
  1791. self.__class__.databases = {"defualt"}
  1792. message = (
  1793. "test_utils.tests.DatabaseAliasTests.databases refers to 'defualt' which "
  1794. "is not defined in settings.DATABASES. Did you mean 'default'?"
  1795. )
  1796. with self.assertRaisesMessage(ImproperlyConfigured, message):
  1797. self._validate_databases()
  1798. def test_match(self):
  1799. self.__class__.databases = {"default", "other"}
  1800. self.assertEqual(self._validate_databases(), frozenset({"default", "other"}))
  1801. def test_all(self):
  1802. self.__class__.databases = "__all__"
  1803. self.assertEqual(self._validate_databases(), frozenset(connections))
  1804. @isolate_apps("test_utils", attr_name="class_apps")
  1805. class IsolatedAppsTests(SimpleTestCase):
  1806. def test_installed_apps(self):
  1807. self.assertEqual(
  1808. [app_config.label for app_config in self.class_apps.get_app_configs()],
  1809. ["test_utils"],
  1810. )
  1811. def test_class_decoration(self):
  1812. class ClassDecoration(models.Model):
  1813. pass
  1814. self.assertEqual(ClassDecoration._meta.apps, self.class_apps)
  1815. @isolate_apps("test_utils", kwarg_name="method_apps")
  1816. def test_method_decoration(self, method_apps):
  1817. class MethodDecoration(models.Model):
  1818. pass
  1819. self.assertEqual(MethodDecoration._meta.apps, method_apps)
  1820. def test_context_manager(self):
  1821. with isolate_apps("test_utils") as context_apps:
  1822. class ContextManager(models.Model):
  1823. pass
  1824. self.assertEqual(ContextManager._meta.apps, context_apps)
  1825. @isolate_apps("test_utils", kwarg_name="method_apps")
  1826. def test_nested(self, method_apps):
  1827. class MethodDecoration(models.Model):
  1828. pass
  1829. with isolate_apps("test_utils") as context_apps:
  1830. class ContextManager(models.Model):
  1831. pass
  1832. with isolate_apps("test_utils") as nested_context_apps:
  1833. class NestedContextManager(models.Model):
  1834. pass
  1835. self.assertEqual(MethodDecoration._meta.apps, method_apps)
  1836. self.assertEqual(ContextManager._meta.apps, context_apps)
  1837. self.assertEqual(NestedContextManager._meta.apps, nested_context_apps)
  1838. class DoNothingDecorator(TestContextDecorator):
  1839. def enable(self):
  1840. pass
  1841. def disable(self):
  1842. pass
  1843. class TestContextDecoratorTests(SimpleTestCase):
  1844. @mock.patch.object(DoNothingDecorator, "disable")
  1845. def test_exception_in_setup(self, mock_disable):
  1846. """An exception is setUp() is reraised after disable() is called."""
  1847. class ExceptionInSetUp(unittest.TestCase):
  1848. def setUp(self):
  1849. raise NotImplementedError("reraised")
  1850. decorator = DoNothingDecorator()
  1851. decorated_test_class = decorator.__call__(ExceptionInSetUp)()
  1852. self.assertFalse(mock_disable.called)
  1853. with self.assertRaisesMessage(NotImplementedError, "reraised"):
  1854. decorated_test_class.setUp()
  1855. decorated_test_class.doCleanups()
  1856. self.assertTrue(mock_disable.called)
  1857. def test_cleanups_run_after_tearDown(self):
  1858. calls = []
  1859. class SaveCallsDecorator(TestContextDecorator):
  1860. def enable(self):
  1861. calls.append("enable")
  1862. def disable(self):
  1863. calls.append("disable")
  1864. class AddCleanupInSetUp(unittest.TestCase):
  1865. def setUp(self):
  1866. calls.append("setUp")
  1867. self.addCleanup(lambda: calls.append("cleanup"))
  1868. decorator = SaveCallsDecorator()
  1869. decorated_test_class = decorator.__call__(AddCleanupInSetUp)()
  1870. decorated_test_class.setUp()
  1871. decorated_test_class.tearDown()
  1872. decorated_test_class.doCleanups()
  1873. self.assertEqual(calls, ["enable", "setUp", "cleanup", "disable"])