Browse Source

Fixed #29865 -- Added logical XOR support for Q() and querysets.

Ryan Heard 3 years ago
parent
commit
c6b4d62fa2

+ 1 - 0
AUTHORS

@@ -833,6 +833,7 @@ answer newbie questions, and generally made Django that much better:
     Russell Keith-Magee <russell@keith-magee.com>
     Russ Webber
     Ryan Hall <ryanhall989@gmail.com>
+    Ryan Heard <ryanwheard@gmail.com>
     ryankanno
     Ryan Kelly <ryan@rfk.id.au>
     Ryan Niemeyer <https://profiles.google.com/ryan.niemeyer/about>

+ 3 - 0
django/db/backends/base/features.py

@@ -325,6 +325,9 @@ class BaseDatabaseFeatures:
     # Does the backend support non-deterministic collations?
     supports_non_deterministic_collations = True
 
+    # Does the backend support the logical XOR operator?
+    supports_logical_xor = False
+
     # Collation names for use by the Django test suite.
     test_collations = {
         "ci": None,  # Case-insensitive.

+ 1 - 0
django/db/backends/mysql/features.py

@@ -47,6 +47,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
 
     supports_order_by_nulls_modifier = False
     order_by_nulls_first = True
+    supports_logical_xor = True
 
     @cached_property
     def minimum_database_version(self):

+ 16 - 4
django/db/models/expressions.py

@@ -94,7 +94,7 @@ class Combinable:
         if getattr(self, "conditional", False) and getattr(other, "conditional", False):
             return Q(self) & Q(other)
         raise NotImplementedError(
-            "Use .bitand() and .bitor() for bitwise logical operations."
+            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
         )
 
     def bitand(self, other):
@@ -106,6 +106,13 @@ class Combinable:
     def bitrightshift(self, other):
         return self._combine(other, self.BITRIGHTSHIFT, False)
 
+    def __xor__(self, other):
+        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
+            return Q(self) ^ Q(other)
+        raise NotImplementedError(
+            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
+        )
+
     def bitxor(self, other):
         return self._combine(other, self.BITXOR, False)
 
@@ -113,7 +120,7 @@ class Combinable:
         if getattr(self, "conditional", False) and getattr(other, "conditional", False):
             return Q(self) | Q(other)
         raise NotImplementedError(
-            "Use .bitand() and .bitor() for bitwise logical operations."
+            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
         )
 
     def bitor(self, other):
@@ -139,12 +146,17 @@ class Combinable:
 
     def __rand__(self, other):
         raise NotImplementedError(
-            "Use .bitand() and .bitor() for bitwise logical operations."
+            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
         )
 
     def __ror__(self, other):
         raise NotImplementedError(
-            "Use .bitand() and .bitor() for bitwise logical operations."
+            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
+        )
+
+    def __rxor__(self, other):
+        raise NotImplementedError(
+            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
         )
 
 

+ 19 - 0
django/db/models/query.py

@@ -396,6 +396,25 @@ class QuerySet:
         combined.query.combine(other.query, sql.OR)
         return combined
 
+    def __xor__(self, other):
+        self._check_operator_queryset(other, "^")
+        self._merge_sanity_check(other)
+        if isinstance(self, EmptyQuerySet):
+            return other
+        if isinstance(other, EmptyQuerySet):
+            return self
+        query = (
+            self
+            if self.query.can_filter()
+            else self.model._base_manager.filter(pk__in=self.values("pk"))
+        )
+        combined = query._chain()
+        combined._merge_known_related_objects(other)
+        if not other.query.can_filter():
+            other = other.model._base_manager.filter(pk__in=other.values("pk"))
+        combined.query.combine(other.query, sql.XOR)
+        return combined
+
     ####################################
     # METHODS THAT DO DATABASE QUERIES #
     ####################################

+ 4 - 0
django/db/models/query_utils.py

@@ -38,6 +38,7 @@ class Q(tree.Node):
     # Connection types
     AND = "AND"
     OR = "OR"
+    XOR = "XOR"
     default = AND
     conditional = True
 
@@ -70,6 +71,9 @@ class Q(tree.Node):
     def __and__(self, other):
         return self._combine(other, self.AND)
 
+    def __xor__(self, other):
+        return self._combine(other, self.XOR)
+
     def __invert__(self):
         obj = type(self)()
         obj.add(self, self.AND)

+ 2 - 2
django/db/models/sql/__init__.py

@@ -1,6 +1,6 @@
 from django.db.models.sql.query import *  # NOQA
 from django.db.models.sql.query import Query
 from django.db.models.sql.subqueries import *  # NOQA
-from django.db.models.sql.where import AND, OR
+from django.db.models.sql.where import AND, OR, XOR
 
-__all__ = ["Query", "AND", "OR"]
+__all__ = ["Query", "AND", "OR", "XOR"]

+ 26 - 4
django/db/models/sql/where.py

@@ -1,14 +1,19 @@
 """
 Code to manage the creation and SQL rendering of 'where' constraints.
 """
+import operator
+from functools import reduce
 
 from django.core.exceptions import EmptyResultSet
+from django.db.models.expressions import Case, When
+from django.db.models.lookups import Exact
 from django.utils import tree
 from django.utils.functional import cached_property
 
 # Connection types
 AND = "AND"
 OR = "OR"
+XOR = "XOR"
 
 
 class WhereNode(tree.Node):
@@ -39,10 +44,12 @@ class WhereNode(tree.Node):
         if not self.contains_aggregate:
             return self, None
         in_negated = negated ^ self.negated
-        # If the effective connector is OR and this node contains an aggregate,
-        # then we need to push the whole branch to HAVING clause.
-        may_need_split = (in_negated and self.connector == AND) or (
-            not in_negated and self.connector == OR
+        # If the effective connector is OR or XOR and this node contains an
+        # aggregate, then we need to push the whole branch to HAVING clause.
+        may_need_split = (
+            (in_negated and self.connector == AND)
+            or (not in_negated and self.connector == OR)
+            or self.connector == XOR
         )
         if may_need_split and self.contains_aggregate:
             return None, self
@@ -85,6 +92,21 @@ class WhereNode(tree.Node):
         else:
             full_needed, empty_needed = 1, len(self.children)
 
+        if self.connector == XOR and not connection.features.supports_logical_xor:
+            # Convert if the database doesn't support XOR:
+            #   a XOR b XOR c XOR ...
+            # to:
+            #   (a OR b OR c OR ...) AND (a + b + c + ...) == 1
+            lhs = self.__class__(self.children, OR)
+            rhs_sum = reduce(
+                operator.add,
+                (Case(When(c, then=1), default=0) for c in self.children),
+            )
+            rhs = Exact(1, rhs_sum)
+            return self.__class__([lhs, rhs], AND, self.negated).as_sql(
+                compiler, connection
+            )
+
         for child in self.children:
             try:
                 sql, params = compiler.compile(child)

+ 40 - 2
docs/ref/models/querysets.txt

@@ -1903,6 +1903,40 @@ SQL equivalent:
 ``|`` is not a commutative operation, as different (though equivalent) queries
 may be generated.
 
+XOR (``^``)
+~~~~~~~~~~~
+
+.. versionadded:: 4.1
+
+Combines two ``QuerySet``\s using the SQL ``XOR`` operator.
+
+The following are equivalent::
+
+    Model.objects.filter(x=1) ^ Model.objects.filter(y=2)
+    from django.db.models import Q
+    Model.objects.filter(Q(x=1) ^ Q(y=2))
+
+SQL equivalent:
+
+.. code-block:: sql
+
+    SELECT ... WHERE x=1 XOR y=2
+
+.. note::
+
+    ``XOR`` is natively supported on MariaDB and MySQL. On other databases,
+    ``x ^ y ^ ... ^ z`` is converted to an equivalent:
+
+    .. code-block:: sql
+
+        (x OR y OR ... OR z) AND
+        1=(
+            (CASE WHEN x THEN 1 ELSE 0 END) +
+            (CASE WHEN y THEN 1 ELSE 0 END) +
+            ...
+            (CASE WHEN z THEN 1 ELSE 0 END) +
+        )
+
 Methods that do not return ``QuerySet``\s
 -----------------------------------------
 
@@ -3751,8 +3785,12 @@ A ``Q()`` object represents an SQL condition that can be used in
 database-related operations. It's similar to how an
 :class:`F() <django.db.models.F>` object represents the value of a model field
 or annotation. They make it possible to define and reuse conditions, and
-combine them using operators such as ``|`` (``OR``) and ``&`` (``AND``). See
-:ref:`complex-lookups-with-q`.
+combine them using operators such as ``|`` (``OR``), ``&`` (``AND``), and ``^``
+(``XOR``). See :ref:`complex-lookups-with-q`.
+
+.. versionchanged:: 4.1
+
+    Support for the ``^`` (``XOR``) operator was added.
 
 ``Prefetch()`` objects
 ----------------------

+ 5 - 0
docs/releases/4.1.txt

@@ -273,6 +273,11 @@ Models
   as the ``chunk_size`` argument is provided. In older versions, no prefetching
   was done.
 
+* :class:`~django.db.models.Q` objects and querysets can now be combined using
+  ``^`` as the exclusive or (``XOR``) operator. ``XOR`` is natively supported
+  on MariaDB and MySQL. For databases that do not support ``XOR``, the query
+  will be converted to an equivalent using ``AND``, ``OR``, and ``NOT``.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 

+ 10 - 5
docs/topics/db/queries.txt

@@ -1111,8 +1111,8 @@ For example, this ``Q`` object encapsulates a single ``LIKE`` query::
     from django.db.models import Q
     Q(question__startswith='What')
 
-``Q`` objects can be combined using the ``&`` and ``|`` operators. When an
-operator is used on two ``Q`` objects, it yields a new ``Q`` object.
+``Q`` objects can be combined using the ``&``, ``|``, and ``^`` operators. When
+an operator is used on two ``Q`` objects, it yields a new ``Q`` object.
 
 For example, this statement yields a single ``Q`` object that represents the
 "OR" of two ``"question__startswith"`` queries::
@@ -1124,9 +1124,10 @@ This is equivalent to the following SQL ``WHERE`` clause::
     WHERE question LIKE 'Who%' OR question LIKE 'What%'
 
 You can compose statements of arbitrary complexity by combining ``Q`` objects
-with the ``&`` and ``|`` operators and use parenthetical grouping. Also, ``Q``
-objects can be negated using the ``~`` operator, allowing for combined lookups
-that combine both a normal query and a negated (``NOT``) query::
+with the ``&``, ``|``, and ``^`` operators and use parenthetical grouping.
+Also, ``Q`` objects can be negated using the ``~`` operator, allowing for
+combined lookups that combine both a normal query and a negated (``NOT``)
+query::
 
     Q(question__startswith='Who') | ~Q(pub_date__year=2005)
 
@@ -1175,6 +1176,10 @@ precede the definition of any keyword arguments. For example::
     The :source:`OR lookups examples <tests/or_lookups/tests.py>` in Django's
     unit tests show some possible uses of ``Q``.
 
+.. versionchanged:: 4.1
+
+    Support for the ``^`` (``XOR``) operator was added.
+
 Comparing objects
 =================
 

+ 22 - 0
tests/aggregation_regress/tests.py

@@ -1704,6 +1704,28 @@ class AggregationTests(TestCase):
             attrgetter("pk"),
         )
 
+    def test_filter_aggregates_xor_connector(self):
+        q1 = Q(price__gt=50)
+        q2 = Q(authors__count__gt=1)
+        query = Book.objects.annotate(Count("authors")).filter(q1 ^ q2).order_by("pk")
+        self.assertQuerysetEqual(
+            query,
+            [self.b1.pk, self.b4.pk, self.b6.pk],
+            attrgetter("pk"),
+        )
+
+    def test_filter_aggregates_negated_xor_connector(self):
+        q1 = Q(price__gt=50)
+        q2 = Q(authors__count__gt=1)
+        query = (
+            Book.objects.annotate(Count("authors")).filter(~(q1 ^ q2)).order_by("pk")
+        )
+        self.assertQuerysetEqual(
+            query,
+            [self.b2.pk, self.b3.pk, self.b5.pk],
+            attrgetter("pk"),
+        )
+
     def test_ticket_11293_q_immutable(self):
         """
         Splitting a q object to parts for where/having doesn't alter

+ 11 - 1
tests/expressions/tests.py

@@ -2339,7 +2339,9 @@ class ReprTests(SimpleTestCase):
 
 
 class CombinableTests(SimpleTestCase):
-    bitwise_msg = "Use .bitand() and .bitor() for bitwise logical operations."
+    bitwise_msg = (
+        "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
+    )
 
     def test_negation(self):
         c = Combinable()
@@ -2353,6 +2355,10 @@ class CombinableTests(SimpleTestCase):
         with self.assertRaisesMessage(NotImplementedError, self.bitwise_msg):
             Combinable() | Combinable()
 
+    def test_xor(self):
+        with self.assertRaisesMessage(NotImplementedError, self.bitwise_msg):
+            Combinable() ^ Combinable()
+
     def test_reversed_and(self):
         with self.assertRaisesMessage(NotImplementedError, self.bitwise_msg):
             object() & Combinable()
@@ -2361,6 +2367,10 @@ class CombinableTests(SimpleTestCase):
         with self.assertRaisesMessage(NotImplementedError, self.bitwise_msg):
             object() | Combinable()
 
+    def test_reversed_xor(self):
+        with self.assertRaisesMessage(NotImplementedError, self.bitwise_msg):
+            object() ^ Combinable()
+
 
 class CombinedExpressionTests(SimpleTestCase):
     def test_resolve_output_field(self):

+ 38 - 0
tests/queries/test_q.py

@@ -27,6 +27,15 @@ class QTests(SimpleTestCase):
         self.assertEqual(q | Q(), q)
         self.assertEqual(Q() | q, q)
 
+    def test_combine_xor_empty(self):
+        q = Q(x=1)
+        self.assertEqual(q ^ Q(), q)
+        self.assertEqual(Q() ^ q, q)
+
+        q = Q(x__in={}.keys())
+        self.assertEqual(q ^ Q(), q)
+        self.assertEqual(Q() ^ q, q)
+
     def test_combine_empty_copy(self):
         base_q = Q(x=1)
         tests = [
@@ -34,6 +43,8 @@ class QTests(SimpleTestCase):
             Q() | base_q,
             base_q & Q(),
             Q() & base_q,
+            base_q ^ Q(),
+            Q() ^ base_q,
         ]
         for i, q in enumerate(tests):
             with self.subTest(i=i):
@@ -43,6 +54,9 @@ class QTests(SimpleTestCase):
     def test_combine_or_both_empty(self):
         self.assertEqual(Q() | Q(), Q())
 
+    def test_combine_xor_both_empty(self):
+        self.assertEqual(Q() ^ Q(), Q())
+
     def test_combine_not_q_object(self):
         obj = object()
         q = Q(x=1)
@@ -50,12 +64,15 @@ class QTests(SimpleTestCase):
             q | obj
         with self.assertRaisesMessage(TypeError, str(obj)):
             q & obj
+        with self.assertRaisesMessage(TypeError, str(obj)):
+            q ^ obj
 
     def test_combine_negated_boolean_expression(self):
         tagged = Tag.objects.filter(category=OuterRef("pk"))
         tests = [
             Q() & ~Exists(tagged),
             Q() | ~Exists(tagged),
+            Q() ^ ~Exists(tagged),
         ]
         for q in tests:
             with self.subTest(q=q):
@@ -88,6 +105,20 @@ class QTests(SimpleTestCase):
         )
         self.assertEqual(kwargs, {"_connector": "OR"})
 
+    def test_deconstruct_xor(self):
+        q1 = Q(price__gt=F("discounted_price"))
+        q2 = Q(price=F("discounted_price"))
+        q = q1 ^ q2
+        path, args, kwargs = q.deconstruct()
+        self.assertEqual(
+            args,
+            (
+                ("price__gt", F("discounted_price")),
+                ("price", F("discounted_price")),
+            ),
+        )
+        self.assertEqual(kwargs, {"_connector": "XOR"})
+
     def test_deconstruct_and(self):
         q1 = Q(price__gt=F("discounted_price"))
         q2 = Q(price=F("discounted_price"))
@@ -144,6 +175,13 @@ class QTests(SimpleTestCase):
         path, args, kwargs = q.deconstruct()
         self.assertEqual(Q(*args, **kwargs), q)
 
+    def test_reconstruct_xor(self):
+        q1 = Q(price__gt=F("discounted_price"))
+        q2 = Q(price=F("discounted_price"))
+        q = q1 ^ q2
+        path, args, kwargs = q.deconstruct()
+        self.assertEqual(Q(*args, **kwargs), q)
+
     def test_reconstruct_and(self):
         q1 = Q(price__gt=F("discounted_price"))
         q2 = Q(price=F("discounted_price"))

+ 1 - 0
tests/queries/test_qs_combinators.py

@@ -526,6 +526,7 @@ class QuerySetSetOperationTests(TestCase):
         operators = [
             ("|", operator.or_),
             ("&", operator.and_),
+            ("^", operator.xor),
         ]
         for combinator in combinators:
             combined_qs = getattr(qs, combinator)(qs)

+ 37 - 0
tests/queries/tests.py

@@ -1883,6 +1883,10 @@ class Queries5Tests(TestCase):
             Note.objects.exclude(~Q() & ~Q()),
             [self.n1, self.n2],
         )
+        self.assertSequenceEqual(
+            Note.objects.exclude(~Q() ^ ~Q()),
+            [self.n1, self.n2],
+        )
 
     def test_extra_select_literal_percent_s(self):
         # Allow %%s to escape select clauses
@@ -2129,6 +2133,15 @@ class Queries6Tests(TestCase):
         sql = captured_queries[0]["sql"]
         self.assertIn("AS %s" % connection.ops.quote_name("col1"), sql)
 
+    def test_xor_subquery(self):
+        self.assertSequenceEqual(
+            Tag.objects.filter(
+                Exists(Tag.objects.filter(id=OuterRef("id"), name="t3"))
+                ^ Exists(Tag.objects.filter(id=OuterRef("id"), parent=self.t1))
+            ),
+            [self.t2],
+        )
+
 
 class RawQueriesTests(TestCase):
     @classmethod
@@ -2432,6 +2445,30 @@ class QuerySetBitwiseOperationTests(TestCase):
         qs2 = Classroom.objects.filter(has_blackboard=True).order_by("-name")[:1]
         self.assertCountEqual(qs1 | qs2, [self.room_3, self.room_4])
 
+    @skipUnlessDBFeature("allow_sliced_subqueries_with_in")
+    def test_xor_with_rhs_slice(self):
+        qs1 = Classroom.objects.filter(has_blackboard=True)
+        qs2 = Classroom.objects.filter(has_blackboard=False)[:1]
+        self.assertCountEqual(qs1 ^ qs2, [self.room_1, self.room_2, self.room_3])
+
+    @skipUnlessDBFeature("allow_sliced_subqueries_with_in")
+    def test_xor_with_lhs_slice(self):
+        qs1 = Classroom.objects.filter(has_blackboard=True)[:1]
+        qs2 = Classroom.objects.filter(has_blackboard=False)
+        self.assertCountEqual(qs1 ^ qs2, [self.room_1, self.room_2, self.room_4])
+
+    @skipUnlessDBFeature("allow_sliced_subqueries_with_in")
+    def test_xor_with_both_slice(self):
+        qs1 = Classroom.objects.filter(has_blackboard=False)[:1]
+        qs2 = Classroom.objects.filter(has_blackboard=True)[:1]
+        self.assertCountEqual(qs1 ^ qs2, [self.room_1, self.room_2])
+
+    @skipUnlessDBFeature("allow_sliced_subqueries_with_in")
+    def test_xor_with_both_slice_and_ordering(self):
+        qs1 = Classroom.objects.filter(has_blackboard=False).order_by("-pk")[:1]
+        qs2 = Classroom.objects.filter(has_blackboard=True).order_by("-name")[:1]
+        self.assertCountEqual(qs1 ^ qs2, [self.room_3, self.room_4])
+
     def test_subquery_aliases(self):
         combined = School.objects.filter(pk__isnull=False) & School.objects.filter(
             Exists(

+ 0 - 0
tests/xor_lookups/__init__.py


+ 8 - 0
tests/xor_lookups/models.py

@@ -0,0 +1,8 @@
+from django.db import models
+
+
+class Number(models.Model):
+    num = models.IntegerField()
+
+    def __str__(self):
+        return str(self.num)

+ 67 - 0
tests/xor_lookups/tests.py

@@ -0,0 +1,67 @@
+from django.db.models import Q
+from django.test import TestCase
+
+from .models import Number
+
+
+class XorLookupsTests(TestCase):
+    @classmethod
+    def setUpTestData(cls):
+        cls.numbers = [Number.objects.create(num=i) for i in range(10)]
+
+    def test_filter(self):
+        self.assertCountEqual(
+            Number.objects.filter(num__lte=7) ^ Number.objects.filter(num__gte=3),
+            self.numbers[:3] + self.numbers[8:],
+        )
+        self.assertCountEqual(
+            Number.objects.filter(Q(num__lte=7) ^ Q(num__gte=3)),
+            self.numbers[:3] + self.numbers[8:],
+        )
+
+    def test_filter_negated(self):
+        self.assertCountEqual(
+            Number.objects.filter(Q(num__lte=7) ^ ~Q(num__lt=3)),
+            self.numbers[:3] + self.numbers[8:],
+        )
+        self.assertCountEqual(
+            Number.objects.filter(~Q(num__gt=7) ^ ~Q(num__lt=3)),
+            self.numbers[:3] + self.numbers[8:],
+        )
+        self.assertCountEqual(
+            Number.objects.filter(Q(num__lte=7) ^ ~Q(num__lt=3) ^ Q(num__lte=1)),
+            [self.numbers[2]] + self.numbers[8:],
+        )
+        self.assertCountEqual(
+            Number.objects.filter(~(Q(num__lte=7) ^ ~Q(num__lt=3) ^ Q(num__lte=1))),
+            self.numbers[:2] + self.numbers[3:8],
+        )
+
+    def test_exclude(self):
+        self.assertCountEqual(
+            Number.objects.exclude(Q(num__lte=7) ^ Q(num__gte=3)),
+            self.numbers[3:8],
+        )
+
+    def test_stages(self):
+        numbers = Number.objects.all()
+        self.assertSequenceEqual(
+            numbers.filter(num__gte=0) ^ numbers.filter(num__lte=11),
+            [],
+        )
+        self.assertSequenceEqual(
+            numbers.filter(num__gt=0) ^ numbers.filter(num__lt=11),
+            [self.numbers[0]],
+        )
+
+    def test_pk_q(self):
+        self.assertCountEqual(
+            Number.objects.filter(Q(pk=self.numbers[0].pk) ^ Q(pk=self.numbers[1].pk)),
+            self.numbers[:2],
+        )
+
+    def test_empty_in(self):
+        self.assertCountEqual(
+            Number.objects.filter(Q(pk__in=[]) ^ Q(num__gte=5)),
+            self.numbers[5:],
+        )