2
0

test_generatedfield.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. from django.core.exceptions import FieldError
  2. from django.db import IntegrityError, connection
  3. from django.db.models import F, GeneratedField, IntegerField
  4. from django.db.models.functions import Lower
  5. from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
  6. from .models import (
  7. GeneratedModel,
  8. GeneratedModelNull,
  9. GeneratedModelNullVirtual,
  10. GeneratedModelOutputField,
  11. GeneratedModelOutputFieldVirtual,
  12. GeneratedModelParams,
  13. GeneratedModelParamsVirtual,
  14. GeneratedModelVirtual,
  15. )
  16. class BaseGeneratedFieldTests(SimpleTestCase):
  17. def test_editable_unsupported(self):
  18. with self.assertRaisesMessage(ValueError, "GeneratedField cannot be editable."):
  19. GeneratedField(expression=Lower("name"), editable=True, db_persist=False)
  20. def test_blank_unsupported(self):
  21. with self.assertRaisesMessage(ValueError, "GeneratedField must be blank."):
  22. GeneratedField(expression=Lower("name"), blank=False, db_persist=False)
  23. def test_default_unsupported(self):
  24. msg = "GeneratedField cannot have a default."
  25. with self.assertRaisesMessage(ValueError, msg):
  26. GeneratedField(expression=Lower("name"), default="", db_persist=False)
  27. def test_database_default_unsupported(self):
  28. msg = "GeneratedField cannot have a database default."
  29. with self.assertRaisesMessage(ValueError, msg):
  30. GeneratedField(expression=Lower("name"), db_default="", db_persist=False)
  31. def test_db_persist_required(self):
  32. msg = "GeneratedField.db_persist must be True or False."
  33. with self.assertRaisesMessage(ValueError, msg):
  34. GeneratedField(expression=Lower("name"))
  35. with self.assertRaisesMessage(ValueError, msg):
  36. GeneratedField(expression=Lower("name"), db_persist=None)
  37. def test_deconstruct(self):
  38. field = GeneratedField(expression=F("a") + F("b"), db_persist=True)
  39. _, path, args, kwargs = field.deconstruct()
  40. self.assertEqual(path, "django.db.models.GeneratedField")
  41. self.assertEqual(args, [])
  42. self.assertEqual(kwargs, {"db_persist": True, "expression": F("a") + F("b")})
  43. class GeneratedFieldTestMixin:
  44. def _refresh_if_needed(self, m):
  45. if not connection.features.can_return_columns_from_insert:
  46. m.refresh_from_db()
  47. return m
  48. def test_unsaved_error(self):
  49. m = self.base_model(a=1, b=2)
  50. msg = "Cannot read a generated field from an unsaved model."
  51. with self.assertRaisesMessage(FieldError, msg):
  52. m.field
  53. def test_create(self):
  54. m = self.base_model.objects.create(a=1, b=2)
  55. m = self._refresh_if_needed(m)
  56. self.assertEqual(m.field, 3)
  57. def test_non_nullable_create(self):
  58. with self.assertRaises(IntegrityError):
  59. self.base_model.objects.create()
  60. def test_save(self):
  61. # Insert.
  62. m = self.base_model(a=2, b=4)
  63. m.save()
  64. m = self._refresh_if_needed(m)
  65. self.assertEqual(m.field, 6)
  66. # Update.
  67. m.a = 4
  68. m.save()
  69. m.refresh_from_db()
  70. self.assertEqual(m.field, 8)
  71. def test_update(self):
  72. m = self.base_model.objects.create(a=1, b=2)
  73. self.base_model.objects.update(b=3)
  74. m = self.base_model.objects.get(pk=m.pk)
  75. self.assertEqual(m.field, 4)
  76. def test_bulk_create(self):
  77. m = self.base_model(a=3, b=4)
  78. (m,) = self.base_model.objects.bulk_create([m])
  79. if not connection.features.can_return_rows_from_bulk_insert:
  80. m = self.base_model.objects.get()
  81. self.assertEqual(m.field, 7)
  82. def test_bulk_update(self):
  83. m = self.base_model.objects.create(a=1, b=2)
  84. m.a = 3
  85. self.base_model.objects.bulk_update([m], fields=["a"])
  86. m = self.base_model.objects.get(pk=m.pk)
  87. self.assertEqual(m.field, 5)
  88. def test_output_field_lookups(self):
  89. """Lookups from the output_field are available on GeneratedFields."""
  90. internal_type = IntegerField().get_internal_type()
  91. min_value, max_value = connection.ops.integer_field_range(internal_type)
  92. if min_value is None:
  93. self.skipTest("Backend doesn't define an integer min value.")
  94. if max_value is None:
  95. self.skipTest("Backend doesn't define an integer max value.")
  96. does_not_exist = self.base_model.DoesNotExist
  97. underflow_value = min_value - 1
  98. with self.assertNumQueries(0), self.assertRaises(does_not_exist):
  99. self.base_model.objects.get(field=underflow_value)
  100. with self.assertNumQueries(0), self.assertRaises(does_not_exist):
  101. self.base_model.objects.get(field__lt=underflow_value)
  102. with self.assertNumQueries(0), self.assertRaises(does_not_exist):
  103. self.base_model.objects.get(field__lte=underflow_value)
  104. overflow_value = max_value + 1
  105. with self.assertNumQueries(0), self.assertRaises(does_not_exist):
  106. self.base_model.objects.get(field=overflow_value)
  107. with self.assertNumQueries(0), self.assertRaises(does_not_exist):
  108. self.base_model.objects.get(field__gt=overflow_value)
  109. with self.assertNumQueries(0), self.assertRaises(does_not_exist):
  110. self.base_model.objects.get(field__gte=overflow_value)
  111. @skipUnlessDBFeature("supports_collation_on_charfield")
  112. def test_output_field(self):
  113. collation = connection.features.test_collations.get("non_default")
  114. if not collation:
  115. self.skipTest("Language collations are not supported.")
  116. m = self.output_field_model.objects.create(name="NAME")
  117. field = m._meta.get_field("lower_name")
  118. db_parameters = field.db_parameters(connection)
  119. self.assertEqual(db_parameters["collation"], collation)
  120. self.assertEqual(db_parameters["type"], field.output_field.db_type(connection))
  121. self.assertNotEqual(
  122. db_parameters["type"],
  123. field._resolved_expression.output_field.db_type(connection),
  124. )
  125. def test_model_with_params(self):
  126. m = self.params_model.objects.create()
  127. m = self._refresh_if_needed(m)
  128. self.assertEqual(m.field, "Constant")
  129. def test_nullable(self):
  130. m1 = self.nullable_model.objects.create()
  131. m1 = self._refresh_if_needed(m1)
  132. none_val = "" if connection.features.interprets_empty_strings_as_nulls else None
  133. self.assertEqual(m1.lower_name, none_val)
  134. m2 = self.nullable_model.objects.create(name="NaMe")
  135. m2 = self._refresh_if_needed(m2)
  136. self.assertEqual(m2.lower_name, "name")
  137. @skipUnlessDBFeature("supports_stored_generated_columns")
  138. class StoredGeneratedFieldTests(GeneratedFieldTestMixin, TestCase):
  139. base_model = GeneratedModel
  140. nullable_model = GeneratedModelNull
  141. output_field_model = GeneratedModelOutputField
  142. params_model = GeneratedModelParams
  143. @skipUnlessDBFeature("supports_virtual_generated_columns")
  144. class VirtualGeneratedFieldTests(GeneratedFieldTestMixin, TestCase):
  145. base_model = GeneratedModelVirtual
  146. nullable_model = GeneratedModelNullVirtual
  147. output_field_model = GeneratedModelOutputFieldVirtual
  148. params_model = GeneratedModelParamsVirtual