Browse Source

Fixed #16211 -- Added logical NOT support to F expressions.

David Wobrock 2 years ago
parent
commit
a320aab512

+ 58 - 25
django/db/models/expressions.py

@@ -162,6 +162,9 @@ class Combinable:
             "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
             "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
         )
         )
 
 
+    def __invert__(self):
+        return NegatedExpression(self)
+
 
 
 class BaseExpression:
 class BaseExpression:
     """Base class for all query expressions."""
     """Base class for all query expressions."""
@@ -827,6 +830,9 @@ class F(Combinable):
     def __hash__(self):
     def __hash__(self):
         return hash(self.name)
         return hash(self.name)
 
 
+    def copy(self):
+        return copy.copy(self)
+
 
 
 class ResolvedOuterRef(F):
 class ResolvedOuterRef(F):
     """
     """
@@ -1252,6 +1258,57 @@ class ExpressionWrapper(SQLiteNumericMixin, Expression):
         return "{}({})".format(self.__class__.__name__, self.expression)
         return "{}({})".format(self.__class__.__name__, self.expression)
 
 
 
 
+class NegatedExpression(ExpressionWrapper):
+    """The logical negation of a conditional expression."""
+
+    def __init__(self, expression):
+        super().__init__(expression, output_field=fields.BooleanField())
+
+    def __invert__(self):
+        return self.expression.copy()
+
+    def as_sql(self, compiler, connection):
+        try:
+            sql, params = super().as_sql(compiler, connection)
+        except EmptyResultSet:
+            features = compiler.connection.features
+            if not features.supports_boolean_expr_in_select_clause:
+                return "1=1", ()
+            return compiler.compile(Value(True))
+        ops = compiler.connection.ops
+        # Some database backends (e.g. Oracle) don't allow EXISTS() and filters
+        # to be compared to another expression unless they're wrapped in a CASE
+        # WHEN.
+        if not ops.conditional_expression_supported_in_where_clause(self.expression):
+            return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params
+        return f"NOT {sql}", params
+
+    def resolve_expression(
+        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
+    ):
+        resolved = super().resolve_expression(
+            query, allow_joins, reuse, summarize, for_save
+        )
+        if not getattr(resolved.expression, "conditional", False):
+            raise TypeError("Cannot negate non-conditional expressions.")
+        return resolved
+
+    def select_format(self, compiler, sql, params):
+        # Wrap boolean expressions with a CASE WHEN expression if a database
+        # backend (e.g. Oracle) doesn't support boolean expression in SELECT or
+        # GROUP BY list.
+        expression_supported_in_where_clause = (
+            compiler.connection.ops.conditional_expression_supported_in_where_clause
+        )
+        if (
+            not compiler.connection.features.supports_boolean_expr_in_select_clause
+            # Avoid double wrapping.
+            and expression_supported_in_where_clause(self.expression)
+        ):
+            sql = "CASE WHEN {} THEN 1 ELSE 0 END".format(sql)
+        return sql, params
+
+
 @deconstructible(path="django.db.models.When")
 @deconstructible(path="django.db.models.When")
 class When(Expression):
 class When(Expression):
     template = "WHEN %(condition)s THEN %(result)s"
     template = "WHEN %(condition)s THEN %(result)s"
@@ -1486,34 +1543,10 @@ class Exists(Subquery):
     template = "EXISTS(%(subquery)s)"
     template = "EXISTS(%(subquery)s)"
     output_field = fields.BooleanField()
     output_field = fields.BooleanField()
 
 
-    def __init__(self, queryset, negated=False, **kwargs):
-        self.negated = negated
+    def __init__(self, queryset, **kwargs):
         super().__init__(queryset, **kwargs)
         super().__init__(queryset, **kwargs)
         self.query = self.query.exists()
         self.query = self.query.exists()
 
 
-    def __invert__(self):
-        clone = self.copy()
-        clone.negated = not self.negated
-        return clone
-
-    def as_sql(self, compiler, connection, **extra_context):
-        try:
-            sql, params = super().as_sql(
-                compiler,
-                connection,
-                **extra_context,
-            )
-        except EmptyResultSet:
-            if self.negated:
-                features = compiler.connection.features
-                if not features.supports_boolean_expr_in_select_clause:
-                    return "1=1", ()
-                return compiler.compile(Value(True))
-            raise
-        if self.negated:
-            sql = "NOT {}".format(sql)
-        return sql, params
-
     def select_format(self, compiler, sql, params):
     def select_format(self, compiler, sql, params):
         # Wrap EXISTS() with a CASE WHEN expression if a database backend
         # Wrap EXISTS() with a CASE WHEN expression if a database backend
         # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
         # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP

+ 13 - 0
docs/ref/models/expressions.txt

@@ -255,6 +255,19 @@ is null) after companies that have been contacted::
     from django.db.models import F
     from django.db.models import F
     Company.objects.order_by(F('last_contacted').desc(nulls_last=True))
     Company.objects.order_by(F('last_contacted').desc(nulls_last=True))
 
 
+Using ``F()`` with logical operations
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. versionadded:: 4.2
+
+``F()`` expressions that output ``BooleanField`` can be logically negated with
+the inversion operator ``~F()``. For example, to swap the activation status of
+companies::
+
+    from django.db.models import F
+
+    Company.objects.update(is_active=~F('is_active'))
+
 .. _func-expressions:
 .. _func-expressions:
 
 
 ``Func()`` expressions
 ``Func()`` expressions

+ 6 - 0
docs/releases/4.2.txt

@@ -236,6 +236,9 @@ Models
 * :class:`~django.db.models.functions.Now` now supports microsecond precision
 * :class:`~django.db.models.functions.Now` now supports microsecond precision
   on MySQL and millisecond precision on SQLite.
   on MySQL and millisecond precision on SQLite.
 
 
+* :class:`F() <django.db.models.F>` expressions that output ``BooleanField``
+  can now be negated using ``~F()`` (inversion operator).
+
 Requests and Responses
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 ~~~~~~~~~~~~~~~~~~~~~~
 
 
@@ -345,6 +348,9 @@ Miscellaneous
 * The minimum supported version of ``sqlparse`` is increased from 0.2.2 to
 * The minimum supported version of ``sqlparse`` is increased from 0.2.2 to
   0.2.3.
   0.2.3.
 
 
+* The undocumented ``negated`` parameter of the
+  :class:`~django.db.models.Exists` expression is removed.
+
 .. _deprecated-features-4.2:
 .. _deprecated-features-4.2:
 
 
 Features deprecated in 4.2
 Features deprecated in 4.2

+ 56 - 0
tests/expressions/tests.py

@@ -48,6 +48,7 @@ from django.db.models.expressions import (
     Col,
     Col,
     Combinable,
     Combinable,
     CombinedExpression,
     CombinedExpression,
+    NegatedExpression,
     RawSQL,
     RawSQL,
     Ref,
     Ref,
 )
 )
@@ -2536,6 +2537,61 @@ class ExpressionWrapperTests(SimpleTestCase):
         self.assertEqual(group_by_cols[0].output_field, expr.output_field)
         self.assertEqual(group_by_cols[0].output_field, expr.output_field)
 
 
 
 
+class NegatedExpressionTests(TestCase):
+    @classmethod
+    def setUpTestData(cls):
+        ceo = Employee.objects.create(firstname="Joe", lastname="Smith", salary=10)
+        cls.eu_company = Company.objects.create(
+            name="Example Inc.",
+            num_employees=2300,
+            num_chairs=5,
+            ceo=ceo,
+            based_in_eu=True,
+        )
+        cls.non_eu_company = Company.objects.create(
+            name="Foobar Ltd.",
+            num_employees=3,
+            num_chairs=4,
+            ceo=ceo,
+            based_in_eu=False,
+        )
+
+    def test_invert(self):
+        f = F("field")
+        self.assertEqual(~f, NegatedExpression(f))
+        self.assertIsNot(~~f, f)
+        self.assertEqual(~~f, f)
+
+    def test_filter(self):
+        self.assertSequenceEqual(
+            Company.objects.filter(~F("based_in_eu")),
+            [self.non_eu_company],
+        )
+
+        qs = Company.objects.annotate(eu_required=~Value(False))
+        self.assertSequenceEqual(
+            qs.filter(based_in_eu=F("eu_required")).order_by("eu_required"),
+            [self.eu_company],
+        )
+        self.assertSequenceEqual(
+            qs.filter(based_in_eu=~~F("eu_required")),
+            [self.eu_company],
+        )
+        self.assertSequenceEqual(
+            qs.filter(based_in_eu=~F("eu_required")),
+            [self.non_eu_company],
+        )
+        self.assertSequenceEqual(qs.filter(based_in_eu=~F("based_in_eu")), [])
+
+    def test_values(self):
+        self.assertSequenceEqual(
+            Company.objects.annotate(negated=~F("based_in_eu"))
+            .values_list("name", "negated")
+            .order_by("name"),
+            [("Example Inc.", False), ("Foobar Ltd.", True)],
+        )
+
+
 class OrderByTests(SimpleTestCase):
 class OrderByTests(SimpleTestCase):
     def test_equal(self):
     def test_equal(self):
         self.assertEqual(
         self.assertEqual(

+ 2 - 2
tests/queries/test_q.py

@@ -8,7 +8,7 @@ from django.db.models import (
     Q,
     Q,
     Value,
     Value,
 )
 )
-from django.db.models.expressions import RawSQL
+from django.db.models.expressions import NegatedExpression, RawSQL
 from django.db.models.functions import Lower
 from django.db.models.functions import Lower
 from django.db.models.sql.where import NothingNode
 from django.db.models.sql.where import NothingNode
 from django.test import SimpleTestCase, TestCase
 from django.test import SimpleTestCase, TestCase
@@ -87,7 +87,7 @@ class QTests(SimpleTestCase):
         ]
         ]
         for q in tests:
         for q in tests:
             with self.subTest(q=q):
             with self.subTest(q=q):
-                self.assertIs(q.negated, True)
+                self.assertIsInstance(q, NegatedExpression)
 
 
     def test_deconstruct(self):
     def test_deconstruct(self):
         q = Q(price__gt=F("discounted_price"))
         q = Q(price__gt=F("discounted_price"))

+ 1 - 0
tests/update/models.py

@@ -10,6 +10,7 @@ class DataPoint(models.Model):
     name = models.CharField(max_length=20)
     name = models.CharField(max_length=20)
     value = models.CharField(max_length=20)
     value = models.CharField(max_length=20)
     another_value = models.CharField(max_length=20, blank=True)
     another_value = models.CharField(max_length=20, blank=True)
+    is_active = models.BooleanField(default=True)
 
 
 
 
 class RelatedPoint(models.Model):
 class RelatedPoint(models.Model):

+ 28 - 2
tests/update/tests.py

@@ -2,7 +2,7 @@ import unittest
 
 
 from django.core.exceptions import FieldError
 from django.core.exceptions import FieldError
 from django.db import IntegrityError, connection, transaction
 from django.db import IntegrityError, connection, transaction
-from django.db.models import CharField, Count, F, IntegerField, Max
+from django.db.models import Case, CharField, Count, F, IntegerField, Max, When
 from django.db.models.functions import Abs, Concat, Lower
 from django.db.models.functions import Abs, Concat, Lower
 from django.test import TestCase
 from django.test import TestCase
 from django.test.utils import register_lookup
 from django.test.utils import register_lookup
@@ -81,7 +81,7 @@ class AdvancedTests(TestCase):
     def setUpTestData(cls):
     def setUpTestData(cls):
         cls.d0 = DataPoint.objects.create(name="d0", value="apple")
         cls.d0 = DataPoint.objects.create(name="d0", value="apple")
         cls.d2 = DataPoint.objects.create(name="d2", value="banana")
         cls.d2 = DataPoint.objects.create(name="d2", value="banana")
-        cls.d3 = DataPoint.objects.create(name="d3", value="banana")
+        cls.d3 = DataPoint.objects.create(name="d3", value="banana", is_active=False)
         cls.r1 = RelatedPoint.objects.create(name="r1", data=cls.d3)
         cls.r1 = RelatedPoint.objects.create(name="r1", data=cls.d3)
 
 
     def test_update(self):
     def test_update(self):
@@ -249,6 +249,32 @@ class AdvancedTests(TestCase):
         Bar.objects.annotate(abs_id=Abs("m2m_foo")).order_by("abs_id").update(x=3)
         Bar.objects.annotate(abs_id=Abs("m2m_foo")).order_by("abs_id").update(x=3)
         self.assertEqual(Bar.objects.get().x, 3)
         self.assertEqual(Bar.objects.get().x, 3)
 
 
+    def test_update_negated_f(self):
+        DataPoint.objects.update(is_active=~F("is_active"))
+        self.assertCountEqual(
+            DataPoint.objects.values_list("name", "is_active"),
+            [("d0", False), ("d2", False), ("d3", True)],
+        )
+        DataPoint.objects.update(is_active=~F("is_active"))
+        self.assertCountEqual(
+            DataPoint.objects.values_list("name", "is_active"),
+            [("d0", True), ("d2", True), ("d3", False)],
+        )
+
+    def test_update_negated_f_conditional_annotation(self):
+        DataPoint.objects.annotate(
+            is_d2=Case(When(name="d2", then=True), default=False)
+        ).update(is_active=~F("is_d2"))
+        self.assertCountEqual(
+            DataPoint.objects.values_list("name", "is_active"),
+            [("d0", True), ("d2", False), ("d3", True)],
+        )
+
+    def test_updating_non_conditional_field(self):
+        msg = "Cannot negate non-conditional expressions."
+        with self.assertRaisesMessage(TypeError, msg):
+            DataPoint.objects.update(is_active=~F("name"))
+
 
 
 @unittest.skipUnless(
 @unittest.skipUnless(
     connection.vendor == "mysql",
     connection.vendor == "mysql",