Browse Source

Fixed #29799 -- Allowed registering lookups per field instances.

Thanks Simon Charette and Mariusz Felisiak for reviews and mentoring
this Google Summer of Code 2022 project.
Allen Jonathan David 2 years ago
parent
commit
cd1afd553f

+ 1 - 1
django/db/models/fields/related.py

@@ -858,7 +858,7 @@ class ForeignObject(RelatedField):
 
     @classmethod
     @functools.lru_cache(maxsize=None)
-    def get_lookups(cls):
+    def get_class_lookups(cls):
         bases = inspect.getmro(cls)
         bases = bases[: bases.index(ForeignObject) + 1]
         class_lookups = [parent.__dict__.get("class_lookups", {}) for parent in bases]

+ 61 - 13
django/db/models/query_utils.py

@@ -188,19 +188,42 @@ class DeferredAttribute:
         return None
 
 
+class class_or_instance_method:
+    """
+    Hook used in RegisterLookupMixin to return partial functions depending on
+    the caller type (instance or class of models.Field).
+    """
+
+    def __init__(self, class_method, instance_method):
+        self.class_method = class_method
+        self.instance_method = instance_method
+
+    def __get__(self, instance, owner):
+        if instance is None:
+            return functools.partial(self.class_method, owner)
+        return functools.partial(self.instance_method, instance)
+
+
 class RegisterLookupMixin:
-    @classmethod
-    def _get_lookup(cls, lookup_name):
-        return cls.get_lookups().get(lookup_name, None)
+    def _get_lookup(self, lookup_name):
+        return self.get_lookups().get(lookup_name, None)
 
-    @classmethod
     @functools.lru_cache(maxsize=None)
-    def get_lookups(cls):
+    def get_class_lookups(cls):
         class_lookups = [
             parent.__dict__.get("class_lookups", {}) for parent in inspect.getmro(cls)
         ]
         return cls.merge_dicts(class_lookups)
 
+    def get_instance_lookups(self):
+        class_lookups = self.get_class_lookups()
+        if instance_lookups := getattr(self, "instance_lookups", None):
+            return {**class_lookups, **instance_lookups}
+        return class_lookups
+
+    get_lookups = class_or_instance_method(get_class_lookups, get_instance_lookups)
+    get_class_lookups = classmethod(get_class_lookups)
+
     def get_lookup(self, lookup_name):
         from django.db.models.lookups import Lookup
 
@@ -233,22 +256,33 @@ class RegisterLookupMixin:
         return merged
 
     @classmethod
-    def _clear_cached_lookups(cls):
+    def _clear_cached_class_lookups(cls):
         for subclass in subclasses(cls):
-            subclass.get_lookups.cache_clear()
+            subclass.get_class_lookups.cache_clear()
 
-    @classmethod
-    def register_lookup(cls, lookup, lookup_name=None):
+    def register_class_lookup(cls, lookup, lookup_name=None):
         if lookup_name is None:
             lookup_name = lookup.lookup_name
         if "class_lookups" not in cls.__dict__:
             cls.class_lookups = {}
         cls.class_lookups[lookup_name] = lookup
-        cls._clear_cached_lookups()
+        cls._clear_cached_class_lookups()
         return lookup
 
-    @classmethod
-    def _unregister_lookup(cls, lookup, lookup_name=None):
+    def register_instance_lookup(self, lookup, lookup_name=None):
+        if lookup_name is None:
+            lookup_name = lookup.lookup_name
+        if "instance_lookups" not in self.__dict__:
+            self.instance_lookups = {}
+        self.instance_lookups[lookup_name] = lookup
+        return lookup
+
+    register_lookup = class_or_instance_method(
+        register_class_lookup, register_instance_lookup
+    )
+    register_class_lookup = classmethod(register_class_lookup)
+
+    def _unregister_class_lookup(cls, lookup, lookup_name=None):
         """
         Remove given lookup from cls lookups. For use in tests only as it's
         not thread-safe.
@@ -256,7 +290,21 @@ class RegisterLookupMixin:
         if lookup_name is None:
             lookup_name = lookup.lookup_name
         del cls.class_lookups[lookup_name]
-        cls._clear_cached_lookups()
+        cls._clear_cached_class_lookups()
+
+    def _unregister_instance_lookup(self, lookup, lookup_name=None):
+        """
+        Remove given lookup from instance lookups. For use in tests only as
+        it's not thread-safe.
+        """
+        if lookup_name is None:
+            lookup_name = lookup.lookup_name
+        del self.instance_lookups[lookup_name]
+
+    _unregister_lookup = class_or_instance_method(
+        _unregister_class_lookup, _unregister_instance_lookup
+    )
+    _unregister_class_lookup = classmethod(_unregister_class_lookup)
 
 
 def select_related_descend(field, restricted, requested, select_mask, reverse=False):

+ 7 - 2
docs/ref/models/fields.txt

@@ -520,8 +520,13 @@ Registering and fetching lookups
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 ``Field`` implements the :ref:`lookup registration API <lookup-registration-api>`.
-The API can be used to customize which lookups are available for a field class, and
-how lookups are fetched from a field.
+The API can be used to customize which lookups are available for a field class
+and its instances, and how lookups are fetched from a field.
+
+.. versionchanged:: 4.2
+
+    Support for registering lookups on :class:`~django.db.models.Field`
+    instances was added.
 
 .. _model-field-types:
 

+ 29 - 15
docs/ref/models/lookups.txt

@@ -34,7 +34,7 @@ Registration API
 ================
 
 Django uses :class:`~lookups.RegisterLookupMixin` to give a class the interface to
-register lookups on itself. The two prominent examples are
+register lookups on itself or its instances. The two prominent examples are
 :class:`~django.db.models.Field`, the base class of all model fields, and
 :class:`Transform`, the base class of all Django transforms.
 
@@ -44,35 +44,49 @@ register lookups on itself. The two prominent examples are
 
     .. classmethod:: register_lookup(lookup, lookup_name=None)
 
-        Registers a new lookup in the class. For example
-        ``DateField.register_lookup(YearExact)`` will register ``YearExact``
-        lookup on ``DateField``. It overrides a lookup that already exists with
-        the same name. ``lookup_name`` will be used for this lookup if
+        Registers a new lookup in the class or class instance. For example::
+
+            DateField.register_lookup(YearExact)
+            User._meta.get_field('date_joined').register_lookup(MonthExact)
+
+        will register ``YearExact`` lookup on ``DateField`` and ``MonthExact``
+        lookup on the ``User.date_joined`` (you can use :ref:`Field Access API
+        <model-meta-field-api>` to retrieve a single field instance). It
+        overrides a lookup that already exists with the same name. Lookups
+        registered on field instances take precedence over the lookups
+        registered on classes. ``lookup_name`` will be used for this lookup if
         provided, otherwise ``lookup.lookup_name`` will be used.
 
     .. method:: get_lookup(lookup_name)
 
-        Returns the :class:`Lookup` named ``lookup_name`` registered in the class.
-        The default implementation looks recursively on all parent classes
-        and checks if any has a registered lookup named ``lookup_name``, returning
-        the first match.
+        Returns the :class:`Lookup` named ``lookup_name`` registered in the
+        class or class instance depending on what calls it. The default
+        implementation looks recursively on all parent classes and checks if
+        any has a registered lookup named ``lookup_name``, returning the first
+        match. Instance lookups would override any class lookups with the same
+        ``lookup_name``.
 
     .. method:: get_lookups()
 
-        Returns a dictionary of each lookup name registered in the class mapped
-        to the :class:`Lookup` class.
+        Returns a dictionary of each lookup name registered in the class or
+        class instance mapped to the :class:`Lookup` class.
 
     .. method:: get_transform(transform_name)
 
-        Returns a :class:`Transform` named ``transform_name``. The default
-        implementation looks recursively on all parent classes to check if any
-        has the registered transform named ``transform_name``, returning the first
-        match.
+        Returns a :class:`Transform` named ``transform_name`` registered in the
+        class or class instance. The default implementation looks recursively
+        on all parent classes to check if any has the registered transform
+        named ``transform_name``, returning the first match.
 
 For a class to be a lookup, it must follow the :ref:`Query Expression API
 <query-expression>`. :class:`~Lookup` and :class:`~Transform` naturally
 follow this API.
 
+.. versionchanged:: 4.2
+
+    Support for registering lookups on :class:`~django.db.models.Field`
+    instances was added.
+
 .. _query-expression:
 
 The Query Expression API

+ 3 - 0
docs/releases/4.2.txt

@@ -204,6 +204,9 @@ Models
 * :meth:`~.QuerySet.prefetch_related` now supports
   :class:`~django.db.models.Prefetch` objects with sliced querysets.
 
+* :ref:`Registering lookups <lookup-registration-api>` on
+  :class:`~django.db.models.Field` instances is now supported.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 

+ 1 - 0
tests/custom_lookups/models.py

@@ -3,6 +3,7 @@ from django.db import models
 
 class Author(models.Model):
     name = models.CharField(max_length=20)
+    alias = models.CharField(max_length=20)
     age = models.IntegerField(null=True)
     birthdate = models.DateField(null=True)
     average_rating = models.FloatField(null=True)

+ 48 - 5
tests/custom_lookups/tests.py

@@ -330,7 +330,7 @@ class LookupTests(TestCase):
         field = Article._meta.get_field("author")
 
         # clear and re-cache
-        field.get_lookups.cache_clear()
+        field.get_class_lookups.cache_clear()
         self.assertNotIn("exactly", field.get_lookups())
 
         # registration should bust the cache
@@ -670,6 +670,37 @@ class RegisterLookupTests(SimpleTestCase):
             self.assertEqual(author_name.get_lookup("sw"), CustomStartsWith)
         self.assertIsNone(author_name.get_lookup("sw"))
 
+    def test_instance_lookup(self):
+        author_name = Author._meta.get_field("name")
+        author_alias = Author._meta.get_field("alias")
+        with register_lookup(author_name, CustomStartsWith):
+            self.assertEqual(author_name.instance_lookups, {"sw": CustomStartsWith})
+            self.assertEqual(author_name.get_lookup("sw"), CustomStartsWith)
+            self.assertIsNone(author_alias.get_lookup("sw"))
+        self.assertIsNone(author_name.get_lookup("sw"))
+        self.assertEqual(author_name.instance_lookups, {})
+        self.assertIsNone(author_alias.get_lookup("sw"))
+
+    def test_instance_lookup_override_class_lookups(self):
+        author_name = Author._meta.get_field("name")
+        author_alias = Author._meta.get_field("alias")
+        with register_lookup(models.CharField, CustomStartsWith, lookup_name="st_end"):
+            with register_lookup(author_alias, CustomEndsWith, lookup_name="st_end"):
+                self.assertEqual(author_name.get_lookup("st_end"), CustomStartsWith)
+                self.assertEqual(author_alias.get_lookup("st_end"), CustomEndsWith)
+            self.assertEqual(author_name.get_lookup("st_end"), CustomStartsWith)
+            self.assertEqual(author_alias.get_lookup("st_end"), CustomStartsWith)
+        self.assertIsNone(author_name.get_lookup("st_end"))
+        self.assertIsNone(author_alias.get_lookup("st_end"))
+
+    def test_instance_lookup_override(self):
+        author_name = Author._meta.get_field("name")
+        with register_lookup(author_name, CustomStartsWith, lookup_name="st_end"):
+            self.assertEqual(author_name.get_lookup("st_end"), CustomStartsWith)
+            author_name.register_lookup(CustomEndsWith, lookup_name="st_end")
+            self.assertEqual(author_name.get_lookup("st_end"), CustomEndsWith)
+        self.assertIsNone(author_name.get_lookup("st_end"))
+
     def test_lookup_on_transform(self):
         transform = Div3Transform
         with register_lookup(Div3Transform, CustomStartsWith):
@@ -682,10 +713,16 @@ class RegisterLookupTests(SimpleTestCase):
         self.assertEqual(transform.get_lookups(), {})
 
     def test_transform_on_field(self):
-        author_age = Author._meta.get_field("age")
-        with register_lookup(models.IntegerField, Div3Transform):
-            self.assertEqual(author_age.get_transform("div3"), Div3Transform)
-        self.assertIsNone(author_age.get_transform("div3"))
+        author_name = Author._meta.get_field("name")
+        author_alias = Author._meta.get_field("alias")
+        with register_lookup(models.CharField, Div3Transform):
+            self.assertEqual(author_alias.get_transform("div3"), Div3Transform)
+            self.assertEqual(author_name.get_transform("div3"), Div3Transform)
+        with register_lookup(author_alias, Div3Transform):
+            self.assertEqual(author_alias.get_transform("div3"), Div3Transform)
+            self.assertIsNone(author_name.get_transform("div3"))
+        self.assertIsNone(author_alias.get_transform("div3"))
+        self.assertIsNone(author_name.get_transform("div3"))
 
     def test_related_lookup(self):
         article_author = Article._meta.get_field("author")
@@ -693,3 +730,9 @@ class RegisterLookupTests(SimpleTestCase):
             self.assertIsNone(article_author.get_lookup("sw"))
         with register_lookup(models.ForeignKey, RelatedMoreThan):
             self.assertEqual(article_author.get_lookup("rmt"), RelatedMoreThan)
+
+    def test_instance_related_lookup(self):
+        article_author = Article._meta.get_field("author")
+        with register_lookup(article_author, RelatedMoreThan):
+            self.assertEqual(article_author.get_lookup("rmt"), RelatedMoreThan)
+        self.assertIsNone(article_author.get_lookup("rmt"))