Browse Source

Fixed #24141 -- Added QuerySet.contains().

Johan Schiff 4 years ago
parent
commit
d01709aae2

+ 21 - 0
django/db/models/query.py

@@ -808,6 +808,27 @@ class QuerySet:
             return self.query.has_results(using=self.db)
         return bool(self._result_cache)
 
+    def contains(self, obj):
+        """Return True if the queryset contains an object."""
+        self._not_support_combined_queries('contains')
+        if self._fields is not None:
+            raise TypeError(
+                'Cannot call QuerySet.contains() after .values() or '
+                '.values_list().'
+            )
+        try:
+            if obj._meta.concrete_model != self.model._meta.concrete_model:
+                return False
+        except AttributeError:
+            raise TypeError("'obj' must be a model instance.")
+        if obj.pk is None:
+            raise ValueError(
+                'QuerySet.contains() cannot be used on unsaved objects.'
+            )
+        if self._result_cache is not None:
+            return obj in self._result_cache
+        return self.filter(pk=obj.pk).exists()
+
     def _prefetch_related_objects(self):
         # This method can only be called once the result cache has been filled.
         prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)

+ 33 - 17
docs/ref/models/querysets.txt

@@ -2516,24 +2516,11 @@ if not. This tries to perform the query in the simplest and fastest way
 possible, but it *does* execute nearly the same query as a normal
 :class:`.QuerySet` query.
 
-:meth:`~.QuerySet.exists` is useful for searches relating to both
-object membership in a :class:`.QuerySet` and to the existence of any objects in
-a :class:`.QuerySet`, particularly in the context of a large :class:`.QuerySet`.
+:meth:`~.QuerySet.exists` is useful for searches relating to the existence of
+any objects in a :class:`.QuerySet`, particularly in the context of a large
+:class:`.QuerySet`.
 
-The most efficient method of finding whether a model with a unique field
-(e.g. ``primary_key``) is a member of a :class:`.QuerySet` is::
-
-    entry = Entry.objects.get(pk=123)
-    if some_queryset.filter(pk=entry.pk).exists():
-        print("Entry contained in queryset")
-
-Which will be faster than the following which requires evaluating and iterating
-through the entire queryset::
-
-    if entry in some_queryset:
-       print("Entry contained in QuerySet")
-
-And to find whether a queryset contains any items::
+To find whether a queryset contains any items::
 
     if some_queryset.exists():
         print("There is at least one object in some_queryset")
@@ -2552,6 +2539,35 @@ more overall work (one query for the existence check plus an extra one to later
 retrieve the results) than using ``bool(some_queryset)``, which retrieves the
 results and then checks if any were returned.
 
+``contains()``
+~~~~~~~~~~~~~~
+
+.. method:: contains(obj)
+
+.. versionadded:: 4.0
+
+Returns ``True`` if the :class:`.QuerySet` contains ``obj``, and ``False`` if
+not. This tries to perform the query in the simplest and fastest way possible.
+
+:meth:`contains` is useful for checking an object membership in a
+:class:`.QuerySet`, particularly in the context of a large :class:`.QuerySet`.
+
+To check whether a queryset contains a specific item::
+
+    if some_queryset.contains(obj):
+        print('Entry contained in queryset')
+
+This will be faster than the following which requires evaluating and iterating
+through the entire queryset::
+
+    if obj in some_queryset:
+        print('Entry contained in queryset')
+
+Like :meth:`exists`, if ``some_queryset`` has not yet been evaluated, but you
+know that it will be at some point, then using ``some_queryset.contains(obj)``
+will make an additional database query, generally resulting in slower overall
+performance.
+
 ``update()``
 ~~~~~~~~~~~~
 

+ 3 - 1
docs/releases/4.0.txt

@@ -216,7 +216,9 @@ Migrations
 Models
 ~~~~~~
 
-* ...
+* New :meth:`QuerySet.contains(obj) <.QuerySet.contains>` method returns
+  whether the queryset contains the given object. This tries to perform the
+  query in the simplest and fastest way possible.
 
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~

+ 6 - 0
docs/topics/db/optimization.txt

@@ -240,6 +240,12 @@ row in the results, even if it ends up only using a few columns. The
 lot of text data or for fields that might take a lot of processing to convert
 back to Python. As always, profile first, then optimize.
 
+Use ``QuerySet.contains(obj)``
+------------------------------
+
+...if you only want to find out if ``obj`` is in the queryset, rather than
+``if obj in queryset``.
+
 Use ``QuerySet.count()``
 ------------------------
 

+ 1 - 0
tests/basic/tests.py

@@ -602,6 +602,7 @@ class ManagerTest(SimpleTestCase):
         'only',
         'using',
         'exists',
+        'contains',
         'explain',
         '_insert',
         '_update',

+ 62 - 0
tests/queries/test_contains.py

@@ -0,0 +1,62 @@
+from django.test import TestCase
+
+from .models import DumbCategory, NamedCategory, ProxyCategory
+
+
+class ContainsTests(TestCase):
+    @classmethod
+    def setUpTestData(cls):
+        cls.category = DumbCategory.objects.create()
+        cls.proxy_category = ProxyCategory.objects.create()
+
+    def test_unsaved_obj(self):
+        msg = 'QuerySet.contains() cannot be used on unsaved objects.'
+        with self.assertRaisesMessage(ValueError, msg):
+            DumbCategory.objects.contains(DumbCategory())
+
+    def test_obj_type(self):
+        msg = "'obj' must be a model instance."
+        with self.assertRaisesMessage(TypeError, msg):
+            DumbCategory.objects.contains(object())
+
+    def test_values(self):
+        msg = 'Cannot call QuerySet.contains() after .values() or .values_list().'
+        with self.assertRaisesMessage(TypeError, msg):
+            DumbCategory.objects.values_list('pk').contains(self.category)
+        with self.assertRaisesMessage(TypeError, msg):
+            DumbCategory.objects.values('pk').contains(self.category)
+
+    def test_basic(self):
+        with self.assertNumQueries(1):
+            self.assertIs(DumbCategory.objects.contains(self.category), True)
+        # QuerySet.contains() doesn't evaluate a queryset.
+        with self.assertNumQueries(1):
+            self.assertIs(DumbCategory.objects.contains(self.category), True)
+
+    def test_evaluated_queryset(self):
+        qs = DumbCategory.objects.all()
+        proxy_qs = ProxyCategory.objects.all()
+        # Evaluate querysets.
+        list(qs)
+        list(proxy_qs)
+        with self.assertNumQueries(0):
+            self.assertIs(qs.contains(self.category), True)
+            self.assertIs(qs.contains(self.proxy_category), True)
+            self.assertIs(proxy_qs.contains(self.category), True)
+            self.assertIs(proxy_qs.contains(self.proxy_category), True)
+
+    def test_proxy_model(self):
+        with self.assertNumQueries(1):
+            self.assertIs(DumbCategory.objects.contains(self.proxy_category), True)
+        with self.assertNumQueries(1):
+            self.assertIs(ProxyCategory.objects.contains(self.category), True)
+
+    def test_wrong_model(self):
+        qs = DumbCategory.objects.all()
+        named_category = NamedCategory(name='category')
+        with self.assertNumQueries(0):
+            self.assertIs(qs.contains(named_category), False)
+        # Evaluate the queryset.
+        list(qs)
+        with self.assertNumQueries(0):
+            self.assertIs(qs.contains(named_category), False)

+ 6 - 0
tests/queries/test_qs_combinators.py

@@ -404,6 +404,12 @@ class QuerySetSetOperationTests(TestCase):
                         msg % (operation, combinator),
                     ):
                         getattr(getattr(qs, combinator)(qs), operation)()
+            with self.assertRaisesMessage(
+                NotSupportedError,
+                msg % ('contains', combinator),
+            ):
+                obj = Number.objects.first()
+                getattr(qs, combinator)(qs).contains(obj)
 
     def test_get_with_filters_unsupported_on_combined_qs(self):
         qs = Number.objects.all()