Browse Source

Fixed #30657 -- Allowed customizing Field's descriptors with a descriptor_class attribute.

Allows model fields to override the descriptor class used on the model
instance attribute.
Jon Dufresne 5 years ago
parent
commit
5ed20b3aa3

+ 3 - 1
django/db/models/fields/__init__.py

@@ -123,6 +123,8 @@ class Field(RegisterLookupMixin):
     one_to_one = None
     related_model = None
 
+    descriptor_class = DeferredAttribute
+
     # Generic field type description, usually overridden by subclasses
     def _description(self):
         return _('Field of type: %(field_type)s') % {
@@ -738,7 +740,7 @@ class Field(RegisterLookupMixin):
             # if you have a classmethod and a field with the same name, then
             # such fields can't be deferred (we don't have a check for this).
             if not getattr(cls, self.attname, None):
-                setattr(cls, self.attname, DeferredAttribute(self))
+                setattr(cls, self.attname, self.descriptor_class(self))
         if self.choices is not None:
             setattr(cls, 'get_%s_display' % self.name,
                     partialmethod(cls._get_FIELD_display, field=self))

+ 10 - 0
docs/ref/models/fields.txt

@@ -1793,6 +1793,16 @@ Field API reference
 
         where the arguments are interpolated from the field's ``__dict__``.
 
+    .. attribute:: descriptor_class
+
+        .. versionadded:: 3.0
+
+        A class implementing the :py:ref:`descriptor protocol <descriptors>`
+        that is instantiated and assigned to the model instance attribute. The
+        constructor must accept a single argument, the ``Field`` instance.
+        Overriding this class attribute allows for customizing the get and set
+        behavior.
+
     To map a ``Field`` to a database-specific type, Django exposes several
     methods:
 

+ 4 - 0
docs/releases/3.0.txt

@@ -283,6 +283,10 @@ Models
   :class:`~django.db.models.Index` now support app label and class
   interpolation using the ``'%(app_label)s'`` and ``'%(class)s'`` placeholders.
 
+* The new :attr:`.Field.descriptor_class` attribute allows model fields to
+  customize the get and set behavior by overriding their
+  :py:ref:`descriptors <descriptors>`.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 

+ 20 - 0
tests/field_subclassing/fields.py

@@ -1,6 +1,26 @@
 from django.db import models
+from django.db.models.query_utils import DeferredAttribute
 
 
 class CustomTypedField(models.TextField):
     def db_type(self, connection):
         return 'custom_field'
+
+
+class CustomDeferredAttribute(DeferredAttribute):
+    def __get__(self, instance, cls=None):
+        self._count_call(instance, 'get')
+        return super().__get__(instance, cls)
+
+    def __set__(self, instance, value):
+        self._count_call(instance, 'set')
+        instance.__dict__[self.field.attname] = value
+
+    def _count_call(self, instance, get_or_set):
+        count_attr = '_%s_%s_count' % (self.field.attname, get_or_set)
+        count = getattr(instance, count_attr, 0)
+        setattr(instance, count_attr, count + 1)
+
+
+class CustomDescriptorField(models.CharField):
+    descriptor_class = CustomDeferredAttribute

+ 25 - 2
tests/field_subclassing/tests.py

@@ -1,7 +1,7 @@
-from django.db import connection
+from django.db import connection, models
 from django.test import SimpleTestCase
 
-from .fields import CustomTypedField
+from .fields import CustomDescriptorField, CustomTypedField
 
 
 class TestDbType(SimpleTestCase):
@@ -9,3 +9,26 @@ class TestDbType(SimpleTestCase):
     def test_db_parameters_respects_db_type(self):
         f = CustomTypedField()
         self.assertEqual(f.db_parameters(connection)['type'], 'custom_field')
+
+
+class DescriptorClassTest(SimpleTestCase):
+    def test_descriptor_class(self):
+        class CustomDescriptorModel(models.Model):
+            name = CustomDescriptorField(max_length=32)
+
+        m = CustomDescriptorModel()
+        self.assertFalse(hasattr(m, '_name_get_count'))
+        # The field is set to its default in the model constructor.
+        self.assertEqual(m._name_set_count, 1)
+        m.name = 'foo'
+        self.assertFalse(hasattr(m, '_name_get_count'))
+        self.assertEqual(m._name_set_count, 2)
+        self.assertEqual(m.name, 'foo')
+        self.assertEqual(m._name_get_count, 1)
+        self.assertEqual(m._name_set_count, 2)
+        m.name = 'bar'
+        self.assertEqual(m._name_get_count, 1)
+        self.assertEqual(m._name_set_count, 3)
+        self.assertEqual(m.name, 'bar')
+        self.assertEqual(m._name_get_count, 2)
+        self.assertEqual(m._name_set_count, 3)