tests.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. import json
  2. import unittest
  3. from uuid import UUID
  4. import yaml
  5. from django import forms
  6. from django.core import serializers
  7. from django.core.exceptions import FieldError
  8. from django.db import IntegrityError, connection
  9. from django.db.models import CompositePrimaryKey
  10. from django.forms import modelform_factory
  11. from django.test import TestCase
  12. from .models import Comment, Post, Tenant, User
  13. class CommentForm(forms.ModelForm):
  14. class Meta:
  15. model = Comment
  16. fields = "__all__"
  17. class CompositePKTests(TestCase):
  18. maxDiff = None
  19. @classmethod
  20. def setUpTestData(cls):
  21. cls.tenant = Tenant.objects.create()
  22. cls.user = User.objects.create(
  23. tenant=cls.tenant,
  24. id=1,
  25. email="user0001@example.com",
  26. )
  27. cls.comment = Comment.objects.create(tenant=cls.tenant, id=1, user=cls.user)
  28. @staticmethod
  29. def get_constraints(table):
  30. with connection.cursor() as cursor:
  31. return connection.introspection.get_constraints(cursor, table)
  32. def test_pk_updated_if_field_updated(self):
  33. user = User.objects.get(pk=self.user.pk)
  34. self.assertEqual(user.pk, (self.tenant.id, self.user.id))
  35. self.assertIs(user._is_pk_set(), True)
  36. user.tenant_id = 9831
  37. self.assertEqual(user.pk, (9831, self.user.id))
  38. self.assertIs(user._is_pk_set(), True)
  39. user.id = 4321
  40. self.assertEqual(user.pk, (9831, 4321))
  41. self.assertIs(user._is_pk_set(), True)
  42. user.pk = (9132, 3521)
  43. self.assertEqual(user.tenant_id, 9132)
  44. self.assertEqual(user.id, 3521)
  45. self.assertIs(user._is_pk_set(), True)
  46. user.id = None
  47. self.assertEqual(user.pk, (9132, None))
  48. self.assertEqual(user.tenant_id, 9132)
  49. self.assertIsNone(user.id)
  50. self.assertIs(user._is_pk_set(), False)
  51. def test_hash(self):
  52. self.assertEqual(hash(User(pk=(1, 2))), hash((1, 2)))
  53. self.assertEqual(hash(User(tenant_id=2, id=3)), hash((2, 3)))
  54. msg = "Model instances without primary key value are unhashable"
  55. with self.assertRaisesMessage(TypeError, msg):
  56. hash(User())
  57. with self.assertRaisesMessage(TypeError, msg):
  58. hash(User(tenant_id=1))
  59. with self.assertRaisesMessage(TypeError, msg):
  60. hash(User(id=1))
  61. def test_pk_must_be_list_or_tuple(self):
  62. user = User.objects.get(pk=self.user.pk)
  63. test_cases = [
  64. "foo",
  65. 1000,
  66. 3.14,
  67. True,
  68. False,
  69. ]
  70. for pk in test_cases:
  71. with self.assertRaisesMessage(
  72. ValueError, "'pk' must be a list or a tuple."
  73. ):
  74. user.pk = pk
  75. def test_pk_must_have_2_elements(self):
  76. user = User.objects.get(pk=self.user.pk)
  77. test_cases = [
  78. (),
  79. [],
  80. (1000,),
  81. [1000],
  82. (1, 2, 3),
  83. [1, 2, 3],
  84. ]
  85. for pk in test_cases:
  86. with self.assertRaisesMessage(ValueError, "'pk' must have 2 elements."):
  87. user.pk = pk
  88. def test_composite_pk_in_fields(self):
  89. user_fields = {f.name for f in User._meta.get_fields()}
  90. self.assertEqual(user_fields, {"pk", "tenant", "id", "email", "comments"})
  91. comment_fields = {f.name for f in Comment._meta.get_fields()}
  92. self.assertEqual(
  93. comment_fields,
  94. {"pk", "tenant", "id", "user_id", "user", "text"},
  95. )
  96. def test_pk_field(self):
  97. pk = User._meta.get_field("pk")
  98. self.assertIsInstance(pk, CompositePrimaryKey)
  99. self.assertIs(User._meta.pk, pk)
  100. def test_error_on_user_pk_conflict(self):
  101. with self.assertRaises(IntegrityError):
  102. User.objects.create(tenant=self.tenant, id=self.user.id)
  103. def test_error_on_comment_pk_conflict(self):
  104. with self.assertRaises(IntegrityError):
  105. Comment.objects.create(tenant=self.tenant, id=self.comment.id)
  106. @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific test")
  107. def test_get_constraints_postgresql(self):
  108. user_constraints = self.get_constraints(User._meta.db_table)
  109. user_pk = user_constraints["composite_pk_user_pkey"]
  110. self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
  111. self.assertIs(user_pk["primary_key"], True)
  112. comment_constraints = self.get_constraints(Comment._meta.db_table)
  113. comment_pk = comment_constraints["composite_pk_comment_pkey"]
  114. self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
  115. self.assertIs(comment_pk["primary_key"], True)
  116. @unittest.skipUnless(connection.vendor == "sqlite", "SQLite specific test")
  117. def test_get_constraints_sqlite(self):
  118. user_constraints = self.get_constraints(User._meta.db_table)
  119. user_pk = user_constraints["__primary__"]
  120. self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
  121. self.assertIs(user_pk["primary_key"], True)
  122. comment_constraints = self.get_constraints(Comment._meta.db_table)
  123. comment_pk = comment_constraints["__primary__"]
  124. self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
  125. self.assertIs(comment_pk["primary_key"], True)
  126. @unittest.skipUnless(connection.vendor == "mysql", "MySQL specific test")
  127. def test_get_constraints_mysql(self):
  128. user_constraints = self.get_constraints(User._meta.db_table)
  129. user_pk = user_constraints["PRIMARY"]
  130. self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
  131. self.assertIs(user_pk["primary_key"], True)
  132. comment_constraints = self.get_constraints(Comment._meta.db_table)
  133. comment_pk = comment_constraints["PRIMARY"]
  134. self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
  135. self.assertIs(comment_pk["primary_key"], True)
  136. @unittest.skipUnless(connection.vendor == "oracle", "Oracle specific test")
  137. def test_get_constraints_oracle(self):
  138. user_constraints = self.get_constraints(User._meta.db_table)
  139. user_pk = next(c for c in user_constraints.values() if c["primary_key"])
  140. self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
  141. self.assertEqual(user_pk["primary_key"], 1)
  142. comment_constraints = self.get_constraints(Comment._meta.db_table)
  143. comment_pk = next(c for c in comment_constraints.values() if c["primary_key"])
  144. self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
  145. self.assertEqual(comment_pk["primary_key"], 1)
  146. def test_in_bulk(self):
  147. """
  148. Test the .in_bulk() method of composite_pk models.
  149. """
  150. result = Comment.objects.in_bulk()
  151. self.assertEqual(result, {self.comment.pk: self.comment})
  152. result = Comment.objects.in_bulk([self.comment.pk])
  153. self.assertEqual(result, {self.comment.pk: self.comment})
  154. def test_iterator(self):
  155. """
  156. Test the .iterator() method of composite_pk models.
  157. """
  158. result = list(Comment.objects.iterator())
  159. self.assertEqual(result, [self.comment])
  160. def test_query(self):
  161. users = User.objects.values_list("pk").order_by("pk")
  162. self.assertNotIn('AS "pk"', str(users.query))
  163. def test_only(self):
  164. users = User.objects.only("pk")
  165. self.assertSequenceEqual(users, (self.user,))
  166. user = users[0]
  167. with self.assertNumQueries(0):
  168. self.assertEqual(user.pk, (self.user.tenant_id, self.user.id))
  169. self.assertEqual(user.tenant_id, self.user.tenant_id)
  170. self.assertEqual(user.id, self.user.id)
  171. with self.assertNumQueries(1):
  172. self.assertEqual(user.email, self.user.email)
  173. def test_model_forms(self):
  174. fields = ["tenant", "id", "user_id", "text"]
  175. self.assertEqual(list(CommentForm.base_fields), fields)
  176. form = modelform_factory(Comment, fields="__all__")
  177. self.assertEqual(list(form().fields), fields)
  178. with self.assertRaisesMessage(
  179. FieldError, "Unknown field(s) (pk) specified for Comment"
  180. ):
  181. self.assertIsNone(modelform_factory(Comment, fields=["pk"]))
  182. class CompositePKFixturesTests(TestCase):
  183. fixtures = ["tenant"]
  184. def test_objects(self):
  185. tenant_1, tenant_2, tenant_3 = Tenant.objects.order_by("pk")
  186. self.assertEqual(tenant_1.id, 1)
  187. self.assertEqual(tenant_1.name, "Tenant 1")
  188. self.assertEqual(tenant_2.id, 2)
  189. self.assertEqual(tenant_2.name, "Tenant 2")
  190. self.assertEqual(tenant_3.id, 3)
  191. self.assertEqual(tenant_3.name, "Tenant 3")
  192. user_1, user_2, user_3, user_4 = User.objects.order_by("pk")
  193. self.assertEqual(user_1.id, 1)
  194. self.assertEqual(user_1.tenant_id, 1)
  195. self.assertEqual(user_1.pk, (user_1.tenant_id, user_1.id))
  196. self.assertEqual(user_1.email, "user0001@example.com")
  197. self.assertEqual(user_2.id, 2)
  198. self.assertEqual(user_2.tenant_id, 1)
  199. self.assertEqual(user_2.pk, (user_2.tenant_id, user_2.id))
  200. self.assertEqual(user_2.email, "user0002@example.com")
  201. self.assertEqual(user_3.id, 3)
  202. self.assertEqual(user_3.tenant_id, 2)
  203. self.assertEqual(user_3.pk, (user_3.tenant_id, user_3.id))
  204. self.assertEqual(user_3.email, "user0003@example.com")
  205. self.assertEqual(user_4.id, 4)
  206. self.assertEqual(user_4.tenant_id, 2)
  207. self.assertEqual(user_4.pk, (user_4.tenant_id, user_4.id))
  208. self.assertEqual(user_4.email, "user0004@example.com")
  209. post_1, post_2 = Post.objects.order_by("pk")
  210. self.assertEqual(post_1.id, UUID("11111111-1111-1111-1111-111111111111"))
  211. self.assertEqual(post_1.tenant_id, 2)
  212. self.assertEqual(post_1.pk, (post_1.tenant_id, post_1.id))
  213. self.assertEqual(post_2.id, UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"))
  214. self.assertEqual(post_2.tenant_id, 2)
  215. self.assertEqual(post_2.pk, (post_2.tenant_id, post_2.id))
  216. def test_serialize_user_json(self):
  217. users = User.objects.filter(pk=(1, 1))
  218. result = serializers.serialize("json", users)
  219. self.assertEqual(
  220. json.loads(result),
  221. [
  222. {
  223. "model": "composite_pk.user",
  224. "pk": [1, 1],
  225. "fields": {
  226. "email": "user0001@example.com",
  227. "id": 1,
  228. "tenant": 1,
  229. },
  230. }
  231. ],
  232. )
  233. def test_serialize_user_jsonl(self):
  234. users = User.objects.filter(pk=(1, 2))
  235. result = serializers.serialize("jsonl", users)
  236. self.assertEqual(
  237. json.loads(result),
  238. {
  239. "model": "composite_pk.user",
  240. "pk": [1, 2],
  241. "fields": {
  242. "email": "user0002@example.com",
  243. "id": 2,
  244. "tenant": 1,
  245. },
  246. },
  247. )
  248. def test_serialize_user_yaml(self):
  249. users = User.objects.filter(pk=(2, 3))
  250. result = serializers.serialize("yaml", users)
  251. self.assertEqual(
  252. yaml.safe_load(result),
  253. [
  254. {
  255. "model": "composite_pk.user",
  256. "pk": [2, 3],
  257. "fields": {
  258. "email": "user0003@example.com",
  259. "id": 3,
  260. "tenant": 2,
  261. },
  262. },
  263. ],
  264. )
  265. def test_serialize_user_python(self):
  266. users = User.objects.filter(pk=(2, 4))
  267. result = serializers.serialize("python", users)
  268. self.assertEqual(
  269. result,
  270. [
  271. {
  272. "model": "composite_pk.user",
  273. "pk": [2, 4],
  274. "fields": {
  275. "email": "user0004@example.com",
  276. "id": 4,
  277. "tenant": 2,
  278. },
  279. },
  280. ],
  281. )
  282. def test_serialize_post_uuid(self):
  283. posts = Post.objects.filter(pk=(2, "11111111-1111-1111-1111-111111111111"))
  284. result = serializers.serialize("json", posts)
  285. self.assertEqual(
  286. json.loads(result),
  287. [
  288. {
  289. "model": "composite_pk.post",
  290. "pk": [2, "11111111-1111-1111-1111-111111111111"],
  291. "fields": {
  292. "id": "11111111-1111-1111-1111-111111111111",
  293. "tenant": 2,
  294. },
  295. },
  296. ],
  297. )