tests.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. import json
  2. import unittest
  3. from uuid import UUID
  4. try:
  5. import yaml # NOQA
  6. HAS_YAML = True
  7. except ImportError:
  8. HAS_YAML = False
  9. from django import forms
  10. from django.core import serializers
  11. from django.core.exceptions import FieldError
  12. from django.db import IntegrityError, connection
  13. from django.db.models import CompositePrimaryKey
  14. from django.forms import modelform_factory
  15. from django.test import TestCase
  16. from .models import Comment, Post, Tenant, TimeStamped, User
  class CommentForm(forms.ModelForm):
      """ModelForm for ``Comment`` declared with ``fields = "__all__"``."""

      class Meta:
          fields = "__all__"
          model = Comment
  class CompositePKTests(TestCase):
      """Behaviour of models whose primary key is a ``CompositePrimaryKey``."""

      maxDiff = None

      @classmethod
      def setUpTestData(cls):
          # One tenant owning a user (tenant, 1) and a comment (tenant, 1).
          cls.tenant = Tenant.objects.create()
          cls.user = User.objects.create(
              tenant=cls.tenant,
              id=1,
              email="user0001@example.com",
          )
          cls.comment = Comment.objects.create(tenant=cls.tenant, id=1, user=cls.user)

      @staticmethod
      def get_primary_key_columns(table):
          """Return the primary key column names of ``table`` as reported by
          the database introspection API."""
          with connection.cursor() as cursor:
              return connection.introspection.get_primary_key_columns(cursor, table)

      def test_pk_updated_if_field_updated(self):
          """``pk`` tracks its component fields in both directions, and
          ``_is_pk_set()`` becomes False once any component is None."""
          user = User.objects.get(pk=self.user.pk)
          self.assertEqual(user.pk, (self.tenant.id, self.user.id))
          self.assertIs(user._is_pk_set(), True)
          # Changing a component field is reflected in pk.
          user.tenant_id = 9831
          self.assertEqual(user.pk, (9831, self.user.id))
          self.assertIs(user._is_pk_set(), True)
          user.id = 4321
          self.assertEqual(user.pk, (9831, 4321))
          self.assertIs(user._is_pk_set(), True)
          # Assigning pk fans the values out to the component fields.
          user.pk = (9132, 3521)
          self.assertEqual(user.tenant_id, 9132)
          self.assertEqual(user.id, 3521)
          self.assertIs(user._is_pk_set(), True)
          # A single missing component means the pk counts as unset.
          user.id = None
          self.assertEqual(user.pk, (9132, None))
          self.assertEqual(user.tenant_id, 9132)
          self.assertIsNone(user.id)
          self.assertIs(user._is_pk_set(), False)

      def test_hash(self):
          """Instances hash like their pk tuple; a missing or partial pk is
          unhashable."""
          self.assertEqual(hash(User(pk=(1, 2))), hash((1, 2)))
          self.assertEqual(hash(User(tenant_id=2, id=3)), hash((2, 3)))
          msg = "Model instances without primary key value are unhashable"
          for kwargs in ({}, {"tenant_id": 1}, {"id": 1}):
              # subTest identifies which partial pk failed.
              with self.subTest(**kwargs), self.assertRaisesMessage(TypeError, msg):
                  hash(User(**kwargs))

      def test_pk_must_be_list_or_tuple(self):
          """Assigning a non-list/tuple value to ``pk`` raises ValueError."""
          user = User.objects.get(pk=self.user.pk)
          test_cases = ("foo", 1000, 3.14, True, False)
          msg = "'pk' must be a list or a tuple."
          for pk in test_cases:
              with self.subTest(pk=pk), self.assertRaisesMessage(ValueError, msg):
                  user.pk = pk

      def test_pk_must_have_2_elements(self):
          """Assigning a sequence of the wrong length to ``pk`` raises
          ValueError."""
          user = User.objects.get(pk=self.user.pk)
          test_cases = ((), [], (1000,), [1000], (1, 2, 3), [1, 2, 3])
          msg = "'pk' must have 2 elements."
          for pk in test_cases:
              with self.subTest(pk=pk), self.assertRaisesMessage(ValueError, msg):
                  user.pk = pk

      def test_composite_pk_in_fields(self):
          """The composite ``pk`` and its component fields are all listed by
          ``_meta.get_fields()``."""
          expected = {"pk", "tenant", "id"}
          for model in (User, Comment):
              with self.subTest(model=model.__name__):
                  field_names = {f.name for f in model._meta.get_fields()}
                  # Subset check; on failure both sets appear in the message.
                  self.assertLessEqual(expected, field_names)

      def test_pk_field(self):
          """``get_field("pk")`` returns the CompositePrimaryKey, which is also
          ``_meta.pk``."""
          pk = User._meta.get_field("pk")
          self.assertIsInstance(pk, CompositePrimaryKey)
          self.assertIs(User._meta.pk, pk)

      def test_error_on_user_pk_conflict(self):
          """Reusing an existing (tenant, id) pair for a User fails."""
          with self.assertRaises(IntegrityError):
              User.objects.create(tenant=self.tenant, id=self.user.id)

      def test_error_on_comment_pk_conflict(self):
          """Reusing an existing (tenant, id) pair for a Comment fails."""
          with self.assertRaises(IntegrityError):
              Comment.objects.create(tenant=self.tenant, id=self.comment.id, user_id=1)

      def test_get_primary_key_columns(self):
          """Introspection reports every column of each composite key."""
          self.assertEqual(
              self.get_primary_key_columns(User._meta.db_table),
              ["tenant_id", "id"],
          )
          self.assertEqual(
              self.get_primary_key_columns(Comment._meta.db_table),
              ["tenant_id", "comment_id"],
          )

      def test_in_bulk(self):
          """Test the .in_bulk() method of composite_pk models."""
          result = Comment.objects.in_bulk()
          self.assertEqual(result, {self.comment.pk: self.comment})
          result = Comment.objects.in_bulk([self.comment.pk])
          self.assertEqual(result, {self.comment.pk: self.comment})

      def test_iterator(self):
          """Test the .iterator() method of composite_pk models."""
          result = list(Comment.objects.iterator())
          self.assertEqual(result, [self.comment])

      def test_query(self):
          """Selecting and ordering by ``pk`` must not alias a column as
          ``"pk"`` in the generated SQL."""
          users = User.objects.values_list("pk").order_by("pk")
          self.assertNotIn('AS "pk"', str(users.query))

      def test_only(self):
          """``only("pk")`` loads every pk component up front; other fields
          are deferred and cost one query on first access."""
          users = User.objects.only("pk")
          self.assertSequenceEqual(users, (self.user,))
          user = users[0]
          with self.assertNumQueries(0):
              self.assertEqual(user.pk, (self.user.tenant_id, self.user.id))
              self.assertEqual(user.tenant_id, self.user.tenant_id)
              self.assertEqual(user.id, self.user.id)
          with self.assertNumQueries(1):
              self.assertEqual(user.email, self.user.email)

      def test_model_forms(self):
          """ModelForms include the pk's component fields but not ``pk``
          itself, and requesting ``pk`` explicitly is an error."""
          fields = ["tenant", "id", "user_id", "text", "integer"]
          self.assertEqual(list(CommentForm.base_fields), fields)
          form = modelform_factory(Comment, fields="__all__")
          self.assertEqual(list(form().fields), fields)
          with self.assertRaisesMessage(
              FieldError, "Unknown field(s) (pk) specified for Comment"
          ):
              modelform_factory(Comment, fields=["pk"])
  class CompositePKFixturesTests(TestCase):
      """Loading and (de)serializing composite-pk models using the
      ``tenant`` fixture."""

      fixtures = ["tenant"]

      def test_objects(self):
          """The fixture creates the expected tenants, users and posts, and
          each ``pk`` equals its (tenant_id, id) components."""
          tenant_1, tenant_2, tenant_3 = Tenant.objects.order_by("pk")
          self.assertEqual(tenant_1.id, 1)
          self.assertEqual(tenant_1.name, "Tenant 1")
          self.assertEqual(tenant_2.id, 2)
          self.assertEqual(tenant_2.name, "Tenant 2")
          self.assertEqual(tenant_3.id, 3)
          self.assertEqual(tenant_3.name, "Tenant 3")

          user_1, user_2, user_3, user_4 = User.objects.order_by("pk")
          self.assertEqual(user_1.id, 1)
          self.assertEqual(user_1.tenant_id, 1)
          self.assertEqual(user_1.pk, (user_1.tenant_id, user_1.id))
          self.assertEqual(user_1.email, "user0001@example.com")
          self.assertEqual(user_2.id, 2)
          self.assertEqual(user_2.tenant_id, 1)
          self.assertEqual(user_2.pk, (user_2.tenant_id, user_2.id))
          self.assertEqual(user_2.email, "user0002@example.com")
          self.assertEqual(user_3.id, 3)
          self.assertEqual(user_3.tenant_id, 2)
          self.assertEqual(user_3.pk, (user_3.tenant_id, user_3.id))
          self.assertEqual(user_3.email, "user0003@example.com")
          self.assertEqual(user_4.id, 4)
          self.assertEqual(user_4.tenant_id, 2)
          self.assertEqual(user_4.pk, (user_4.tenant_id, user_4.id))
          self.assertEqual(user_4.email, "user0004@example.com")

          post_1, post_2 = Post.objects.order_by("pk")
          self.assertEqual(post_1.id, UUID("11111111-1111-1111-1111-111111111111"))
          self.assertEqual(post_1.tenant_id, 2)
          self.assertEqual(post_1.pk, (post_1.tenant_id, post_1.id))
          self.assertEqual(post_2.id, UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"))
          self.assertEqual(post_2.tenant_id, 2)
          self.assertEqual(post_2.pk, (post_2.tenant_id, post_2.id))

      def assert_deserializer(self, format, users, serialized_users):
          """Deserialize ``serialized_users`` with the ``format`` serializer.

          Check that one object comes back per entry in ``users``, and that
          the first one matches ``users[0]`` field by field, including the
          composite ``pk``.

          ``format`` shadows the builtin of the same name. The name is kept
          because callers pass it by keyword.
          """
          deserialized = list(serializers.deserialize(format, serialized_users))
          # Catch dropped or duplicated objects, not only a bad first object.
          self.assertEqual(len(deserialized), len(users))
          deserialized_user = deserialized[0]
          self.assertEqual(deserialized_user.object.email, users[0].email)
          self.assertEqual(deserialized_user.object.id, users[0].id)
          self.assertEqual(deserialized_user.object.tenant, users[0].tenant)
          self.assertEqual(deserialized_user.object.pk, users[0].pk)

      def test_serialize_user_json(self):
          """A user round-trips through JSON with ``pk`` as a list."""
          users = User.objects.filter(pk=(1, 1))
          result = serializers.serialize("json", users)
          self.assertEqual(
              json.loads(result),
              [
                  {
                      "model": "composite_pk.user",
                      "pk": [1, 1],
                      "fields": {
                          "email": "user0001@example.com",
                          "id": 1,
                          "tenant": 1,
                      },
                  }
              ],
          )
          self.assert_deserializer(format="json", users=users, serialized_users=result)

      def test_serialize_user_jsonl(self):
          """A user round-trips through JSON Lines with ``pk`` as a list."""
          users = User.objects.filter(pk=(1, 2))
          result = serializers.serialize("jsonl", users)
          # A single object means a single line, which json.loads accepts.
          self.assertEqual(
              json.loads(result),
              {
                  "model": "composite_pk.user",
                  "pk": [1, 2],
                  "fields": {
                      "email": "user0002@example.com",
                      "id": 2,
                      "tenant": 1,
                  },
              },
          )
          self.assert_deserializer(format="jsonl", users=users, serialized_users=result)

      @unittest.skipUnless(HAS_YAML, "No yaml library detected")
      def test_serialize_user_yaml(self):
          """A user round-trips through YAML with ``pk`` as a list."""
          users = User.objects.filter(pk=(2, 3))
          result = serializers.serialize("yaml", users)
          self.assertEqual(
              yaml.safe_load(result),
              [
                  {
                      "model": "composite_pk.user",
                      "pk": [2, 3],
                      "fields": {
                          "email": "user0003@example.com",
                          "id": 3,
                          "tenant": 2,
                      },
                  },
              ],
          )
          self.assert_deserializer(format="yaml", users=users, serialized_users=result)

      def test_serialize_user_python(self):
          """A user round-trips through the python serializer."""
          users = User.objects.filter(pk=(2, 4))
          result = serializers.serialize("python", users)
          self.assertEqual(
              result,
              [
                  {
                      "model": "composite_pk.user",
                      "pk": [2, 4],
                      "fields": {
                          "email": "user0004@example.com",
                          "id": 4,
                          "tenant": 2,
                      },
                  },
              ],
          )
          self.assert_deserializer(format="python", users=users, serialized_users=result)

      def test_serialize_user_xml(self):
          """A user round-trips through XML; the pk attribute holds a JSON
          list of strings."""
          users = User.objects.filter(pk=(1, 1))
          result = serializers.serialize("xml", users)
          self.assertIn('<object model="composite_pk.user" pk=\'["1", "1"]\'>', result)
          self.assert_deserializer(format="xml", users=users, serialized_users=result)

      def test_serialize_post_uuid(self):
          """A UUID component of the pk is serialized as its string form."""
          posts = Post.objects.filter(pk=(2, "11111111-1111-1111-1111-111111111111"))
          result = serializers.serialize("json", posts)
          self.assertEqual(
              json.loads(result),
              [
                  {
                      "model": "composite_pk.post",
                      "pk": [2, "11111111-1111-1111-1111-111111111111"],
                      "fields": {
                          "id": "11111111-1111-1111-1111-111111111111",
                          "tenant": 2,
                      },
                  },
              ],
          )

      def test_serialize_datetime(self):
          """A datetime component of the pk is serialized as an ISO string."""
          result = serializers.serialize("json", TimeStamped.objects.all())
          self.assertEqual(
              json.loads(result),
              [
                  {
                      "model": "composite_pk.timestamped",
                      "pk": [1, "2022-01-12T05:55:14.956"],
                      "fields": {
                          "id": 1,
                          "created": "2022-01-12T05:55:14.956",
                          "text": "",
                      },
                  },
              ],
          )

      def test_invalid_pk_extra_field(self):
          """A serialized pk with more values than the key has fields is
          rejected."""
          # Named ``data`` so the module-level ``json`` import is not shadowed.
          data = (
              '[{"fields": {"email": "user0001@example.com", "id": 1, "tenant": 1}, '
              '"pk": [1, 1, "extra"], "model": "composite_pk.user"}]'
          )
          with self.assertRaises(serializers.base.DeserializationError):
              next(serializers.deserialize("json", data))
  304. )
  305. with self.assertRaises(serializers.base.DeserializationError):
  306. next(serializers.deserialize("json", json))