瀏覽代碼

Fixed #3871 -- Custom managers when traversing reverse relations.

Loic Bistuer 11 年之前
父節點
當前提交
04a2a6b0f9

+ 17 - 0
django/contrib/contenttypes/generic.py

@@ -319,6 +319,23 @@ def create_generic_related_manager(superclass):
                 '%s__exact' % object_id_field_name: instance._get_pk_val(),
             }
 
+        def __call__(self, **kwargs):
+            # We use **kwargs rather than a kwarg argument to enforce the
+            # `manager='manager_name'` syntax.
+            manager = getattr(self.model, kwargs.pop('manager'))
+            manager_class = create_generic_related_manager(manager.__class__)
+            return manager_class(
+                model = self.model,
+                instance = self.instance,
+                symmetrical = self.symmetrical,
+                source_col_name = self.source_col_name,
+                target_col_name = self.target_col_name,
+                content_type = self.content_type,
+                content_type_field_name = self.content_type_field_name,
+                object_id_field_name = self.object_id_field_name,
+                prefetch_cache_name = self.prefetch_cache_name,
+            )
+
         def get_queryset(self):
             try:
                 return self.instance._prefetched_objects_cache[self.prefetch_cache_name]

+ 108 - 80
django/db/models/fields/related.py

@@ -365,6 +365,92 @@ class ReverseSingleRelatedObjectDescriptor(six.with_metaclass(RenameRelatedObjec
             setattr(value, self.field.related.get_cache_name(), instance)
 
 
+def create_foreign_related_manager(superclass, rel_field, rel_model):
+    class RelatedManager(superclass):
+        def __init__(self, instance):
+            super(RelatedManager, self).__init__()
+            self.instance = instance
+            self.core_filters = {'%s__exact' % rel_field.name: instance}
+            self.model = rel_model
+
+        def __call__(self, **kwargs):
+            # We use **kwargs rather than a kwarg argument to enforce the
+            # `manager='manager_name'` syntax.
+            manager = getattr(self.model, kwargs.pop('manager'))
+            manager_class = create_foreign_related_manager(manager.__class__, rel_field, rel_model)
+            return manager_class(self.instance)
+
+        def get_queryset(self):
+            try:
+                return self.instance._prefetched_objects_cache[rel_field.related_query_name()]
+            except (AttributeError, KeyError):
+                db = self._db or router.db_for_read(self.model, instance=self.instance)
+                qs = super(RelatedManager, self).get_queryset().using(db).filter(**self.core_filters)
+                empty_strings_as_null = connections[db].features.interprets_empty_strings_as_nulls
+                for field in rel_field.foreign_related_fields:
+                    val = getattr(self.instance, field.attname)
+                    if val is None or (val == '' and empty_strings_as_null):
+                        return qs.none()
+                qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}}
+                return qs
+
+        def get_prefetch_queryset(self, instances):
+            rel_obj_attr = rel_field.get_local_related_value
+            instance_attr = rel_field.get_foreign_related_value
+            instances_dict = dict((instance_attr(inst), inst) for inst in instances)
+            db = self._db or router.db_for_read(self.model, instance=instances[0])
+            query = {'%s__in' % rel_field.name: instances}
+            qs = super(RelatedManager, self).get_queryset().using(db).filter(**query)
+            # Since we just bypassed this class' get_queryset(), we must manage
+            # the reverse relation manually.
+            for rel_obj in qs:
+                instance = instances_dict[rel_obj_attr(rel_obj)]
+                setattr(rel_obj, rel_field.name, instance)
+            cache_name = rel_field.related_query_name()
+            return qs, rel_obj_attr, instance_attr, False, cache_name
+
+        def add(self, *objs):
+            for obj in objs:
+                if not isinstance(obj, self.model):
+                    raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj))
+                setattr(obj, rel_field.name, self.instance)
+                obj.save()
+        add.alters_data = True
+
+        def create(self, **kwargs):
+            kwargs[rel_field.name] = self.instance
+            db = router.db_for_write(self.model, instance=self.instance)
+            return super(RelatedManager, self.db_manager(db)).create(**kwargs)
+        create.alters_data = True
+
+        def get_or_create(self, **kwargs):
+            # Update kwargs with the related object that this
+            # ForeignRelatedObjectsDescriptor knows about.
+            kwargs[rel_field.name] = self.instance
+            db = router.db_for_write(self.model, instance=self.instance)
+            return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs)
+        get_or_create.alters_data = True
+
+        # remove() and clear() are only provided if the ForeignKey can have a value of null.
+        if rel_field.null:
+            def remove(self, *objs):
+                val = rel_field.get_foreign_related_value(self.instance)
+                for obj in objs:
+                    # Is obj actually part of this descriptor set?
+                    if rel_field.get_local_related_value(obj) == val:
+                        setattr(obj, rel_field.name, None)
+                        obj.save()
+                    else:
+                        raise rel_field.rel.to.DoesNotExist("%r is not related to %r." % (obj, self.instance))
+            remove.alters_data = True
+
+            def clear(self):
+                self.update(**{rel_field.name: None})
+            clear.alters_data = True
+
+    return RelatedManager
+
+
 class ForeignRelatedObjectsDescriptor(object):
     # This class provides the functionality that makes the related-object
     # managers available as attributes on a model class, for fields that have
@@ -392,86 +478,11 @@ class ForeignRelatedObjectsDescriptor(object):
     def related_manager_cls(self):
         # Dynamically create a class that subclasses the related model's default
         # manager.
-        superclass = self.related.model._default_manager.__class__
-        rel_field = self.related.field
-        rel_model = self.related.model
-
-        class RelatedManager(superclass):
-            def __init__(self, instance):
-                super(RelatedManager, self).__init__()
-                self.instance = instance
-                self.core_filters = {'%s__exact' % rel_field.name: instance}
-                self.model = rel_model
-
-            def get_queryset(self):
-                try:
-                    return self.instance._prefetched_objects_cache[rel_field.related_query_name()]
-                except (AttributeError, KeyError):
-                    db = self._db or router.db_for_read(self.model, instance=self.instance)
-                    qs = super(RelatedManager, self).get_queryset().using(db).filter(**self.core_filters)
-                    empty_strings_as_null = connections[db].features.interprets_empty_strings_as_nulls
-                    for field in rel_field.foreign_related_fields:
-                        val = getattr(self.instance, field.attname)
-                        if val is None or (val == '' and empty_strings_as_null):
-                            return qs.none()
-                    qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}}
-                    return qs
-
-            def get_prefetch_queryset(self, instances):
-                rel_obj_attr = rel_field.get_local_related_value
-                instance_attr = rel_field.get_foreign_related_value
-                instances_dict = dict((instance_attr(inst), inst) for inst in instances)
-                db = self._db or router.db_for_read(self.model, instance=instances[0])
-                query = {'%s__in' % rel_field.name: instances}
-                qs = super(RelatedManager, self).get_queryset().using(db).filter(**query)
-                # Since we just bypassed this class' get_queryset(), we must manage
-                # the reverse relation manually.
-                for rel_obj in qs:
-                    instance = instances_dict[rel_obj_attr(rel_obj)]
-                    setattr(rel_obj, rel_field.name, instance)
-                cache_name = rel_field.related_query_name()
-                return qs, rel_obj_attr, instance_attr, False, cache_name
-
-            def add(self, *objs):
-                for obj in objs:
-                    if not isinstance(obj, self.model):
-                        raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj))
-                    setattr(obj, rel_field.name, self.instance)
-                    obj.save()
-            add.alters_data = True
-
-            def create(self, **kwargs):
-                kwargs[rel_field.name] = self.instance
-                db = router.db_for_write(self.model, instance=self.instance)
-                return super(RelatedManager, self.db_manager(db)).create(**kwargs)
-            create.alters_data = True
-
-            def get_or_create(self, **kwargs):
-                # Update kwargs with the related object that this
-                # ForeignRelatedObjectsDescriptor knows about.
-                kwargs[rel_field.name] = self.instance
-                db = router.db_for_write(self.model, instance=self.instance)
-                return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs)
-            get_or_create.alters_data = True
-
-            # remove() and clear() are only provided if the ForeignKey can have a value of null.
-            if rel_field.null:
-                def remove(self, *objs):
-                    val = rel_field.get_foreign_related_value(self.instance)
-                    for obj in objs:
-                        # Is obj actually part of this descriptor set?
-                        if rel_field.get_local_related_value(obj) == val:
-                            setattr(obj, rel_field.name, None)
-                            obj.save()
-                        else:
-                            raise rel_field.rel.to.DoesNotExist("%r is not related to %r." % (obj, self.instance))
-                remove.alters_data = True
-
-                def clear(self):
-                    self.update(**{rel_field.name: None})
-                clear.alters_data = True
-
-        return RelatedManager
+        return create_foreign_related_manager(
+            self.related.model._default_manager.__class__,
+            self.related.field,
+            self.related.model,
+        )
 
 
 def create_many_related_manager(superclass, rel):
@@ -513,6 +524,23 @@ def create_many_related_manager(superclass, rel):
                                  "a many-to-many relationship can be used." %
                                  instance.__class__.__name__)
 
+        def __call__(self, **kwargs):
+            # We use **kwargs rather than a kwarg argument to enforce the
+            # `manager='manager_name'` syntax.
+            manager = getattr(self.model, kwargs.pop('manager'))
+            manager_class = create_many_related_manager(manager.__class__, rel)
+            return manager_class(
+                model=self.model,
+                query_field_name=self.query_field_name,
+                instance=self.instance,
+                symmetrical=self.symmetrical,
+                source_field_name=self.source_field_name,
+                target_field_name=self.target_field_name,
+                reverse=self.reverse,
+                through=self.through,
+                prefetch_cache_name=self.prefetch_cache_name,
+            )
+
         def get_queryset(self):
             try:
                 return self.instance._prefetched_objects_cache[self.prefetch_cache_name]

+ 6 - 0
docs/releases/1.7.txt

@@ -92,6 +92,12 @@ The :meth:`QuerySet.as_manager() <django.db.models.query.QuerySet.as_manager>`
 class method has been added to :ref:`create Manager with QuerySet methods
 <create-manager-with-queryset-methods>`.
 
+Using a custom manager when traversing reverse relations
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+It is now possible to :ref:`specify a custom manager
+<using-custom-reverse-manager>` when traversing a reverse relationship.
+
 Admin shortcuts support time zones
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

+ 25 - 0
docs/topics/db/queries.txt

@@ -1136,6 +1136,31 @@ above example code would look like this::
     >>> b.entries.filter(headline__contains='Lennon')
     >>> b.entries.count()
 
+.. _using-custom-reverse-manager:
+
+Using a custom reverse manager
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. versionadded:: 1.7
+
+By default the :class:`~django.db.models.fields.related.RelatedManager` used
+for reverse relations is a subclass of the :ref:`default manager <manager-names>`
+for that model. If you would like to specify a different manager for a given
+query you can use the following syntax::
+
+    from django.db import models
+
+    class Entry(models.Model):
+        #...
+        objects = models.Manager() # Default Manager
+        entries = EntryManager() # Custom Manager
+
+    >>> b = Blog.objects.get(id=1)
+    >>> b.entry_set(manager='entries').all()
+
+Additional methods to handle related objects
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
 In addition to the :class:`~django.db.models.query.QuerySet` methods defined in
 "Retrieving objects" above, the :class:`~django.db.models.ForeignKey`
 :class:`~django.db.models.Manager` has additional methods used to handle the

+ 20 - 0
tests/custom_managers/models.py

@@ -11,6 +11,7 @@ returns.
 
 from __future__ import unicode_literals
 
+from django.contrib.contenttypes import generic
 from django.db import models
 from django.utils.encoding import python_2_unicode_compatible
 
@@ -63,12 +64,28 @@ class BaseCustomManager(models.Manager):
 
 CustomManager = BaseCustomManager.from_queryset(CustomQuerySet)
 
+class FunPeopleManager(models.Manager):
+    def get_queryset(self):
+        return super(FunPeopleManager, self).get_queryset().filter(fun=True)
+
+class BoringPeopleManager(models.Manager):
+    def get_queryset(self):
+        return super(BoringPeopleManager, self).get_queryset().filter(fun=False)
+
 @python_2_unicode_compatible
 class Person(models.Model):
     first_name = models.CharField(max_length=30)
     last_name = models.CharField(max_length=30)
     fun = models.BooleanField(default=False)
+
+    favorite_book = models.ForeignKey('Book', null=True, related_name='favorite_books')
+    favorite_thing_type = models.ForeignKey('contenttypes.ContentType', null=True)
+    favorite_thing_id = models.IntegerField(null=True)
+    favorite_thing = generic.GenericForeignKey('favorite_thing_type', 'favorite_thing_id')
+
     objects = PersonManager()
+    fun_people = FunPeopleManager()
+    boring_people = BoringPeopleManager()
 
     custom_queryset_default_manager = CustomQuerySet.as_manager()
     custom_queryset_custom_manager = CustomManager('hello')
@@ -84,6 +101,9 @@ class Book(models.Model):
     published_objects = PublishedBookManager()
     authors = models.ManyToManyField(Person, related_name='books')
 
+    favorite_things = generic.GenericRelation(Person,
+        content_type_field='favorite_thing_type', object_id_field='favorite_thing_id')
+
     def __str__(self):
         return self.title
 

+ 86 - 11
tests/custom_managers/tests.py

@@ -7,10 +7,15 @@ from .models import Person, Book, Car, PersonManager, PublishedBookManager
 
 
 class CustomManagerTests(TestCase):
-    def test_manager(self):
-        Person.objects.create(first_name="Bugs", last_name="Bunny", fun=True)
-        p2 = Person.objects.create(first_name="Droopy", last_name="Dog", fun=False)
+    def setUp(self):
+        self.b1 = Book.published_objects.create(
+            title="How to program", author="Rodney Dangerfield", is_published=True)
+        self.b2 = Book.published_objects.create(
+            title="How to be smart", author="Albert Einstein", is_published=False)
+        self.p1 = Person.objects.create(first_name="Bugs", last_name="Bunny", fun=True)
+        self.p2 = Person.objects.create(first_name="Droopy", last_name="Dog", fun=False)
 
+    def test_manager(self):
         # Test a custom `Manager` method.
         self.assertQuerysetEqual(
             Person.objects.get_fun_people(), [
@@ -61,14 +66,8 @@ class CustomManagerTests(TestCase):
 
         # The RelatedManager used on the 'books' descriptor extends the default
         # manager
-        self.assertIsInstance(p2.books, PublishedBookManager)
+        self.assertIsInstance(self.p2.books, PublishedBookManager)
 
-        Book.published_objects.create(
-            title="How to program", author="Rodney Dangerfield", is_published=True
-        )
-        b2 = Book.published_objects.create(
-            title="How to be smart", author="Albert Einstein", is_published=False
-        )
 
         # The default manager, "objects", doesn't exist, because a custom one
         # was provided.
@@ -76,7 +75,7 @@ class CustomManagerTests(TestCase):
 
         # The RelatedManager used on the 'authors' descriptor extends the
         # default manager
-        self.assertIsInstance(b2.authors, PersonManager)
+        self.assertIsInstance(self.b2.authors, PersonManager)
 
         self.assertQuerysetEqual(
             Book.published_objects.all(), [
@@ -114,3 +113,79 @@ class CustomManagerTests(TestCase):
             ],
             lambda c: c.name
         )
+
+    def test_related_manager_fk(self):
+        self.p1.favorite_book = self.b1
+        self.p1.save()
+        self.p2.favorite_book = self.b1
+        self.p2.save()
+
+        self.assertQuerysetEqual(
+            self.b1.favorite_books.order_by('first_name').all(), [
+                "Bugs",
+                "Droopy",
+            ],
+            lambda c: c.first_name
+        )
+        self.assertQuerysetEqual(
+            self.b1.favorite_books(manager='boring_people').all(), [
+                "Droopy",
+            ],
+            lambda c: c.first_name
+        )
+        self.assertQuerysetEqual(
+            self.b1.favorite_books(manager='fun_people').all(), [
+                "Bugs",
+            ],
+            lambda c: c.first_name
+        )
+
+    def test_related_manager_gfk(self):
+        self.p1.favorite_thing = self.b1
+        self.p1.save()
+        self.p2.favorite_thing = self.b1
+        self.p2.save()
+
+        self.assertQuerysetEqual(
+            self.b1.favorite_things.order_by('first_name').all(), [
+                "Bugs",
+                "Droopy",
+            ],
+            lambda c: c.first_name
+        )
+        self.assertQuerysetEqual(
+            self.b1.favorite_things(manager='boring_people').all(), [
+                "Droopy",
+            ],
+            lambda c: c.first_name
+        )
+        self.assertQuerysetEqual(
+            self.b1.favorite_things(manager='fun_people').all(), [
+                "Bugs",
+            ],
+            lambda c: c.first_name
+        )
+
+    def test_related_manager_m2m(self):
+        self.b1.authors.add(self.p1)
+        self.b1.authors.add(self.p2)
+
+        self.assertQuerysetEqual(
+            self.b1.authors.order_by('first_name').all(), [
+                "Bugs",
+                "Droopy",
+            ],
+            lambda c: c.first_name
+        )
+        self.assertQuerysetEqual(
+            self.b1.authors(manager='boring_people').all(), [
+                "Droopy",
+            ],
+            lambda c: c.first_name
+        )
+        self.assertQuerysetEqual(
+            self.b1.authors(manager='fun_people').all(), [
+                "Bugs",
+            ],
+            lambda c: c.first_name
+        )