Browse Source

Fixed #16211 -- Added logical NOT support to F expressions.

David Wobrock 2 năm trước cách đây
mục cha
commit
a320aab512

+ 58 - 25
django/db/models/expressions.py

@@ -162,6 +162,9 @@ class Combinable:
             "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
         )
 
+    def __invert__(self):
+        return NegatedExpression(self)
+
 
 class BaseExpression:
     """Base class for all query expressions."""
@@ -827,6 +830,9 @@ class F(Combinable):
     def __hash__(self):
         return hash(self.name)
 
+    def copy(self):
+        return copy.copy(self)
+
 
 class ResolvedOuterRef(F):
     """
@@ -1252,6 +1258,57 @@ class ExpressionWrapper(SQLiteNumericMixin, Expression):
         return "{}({})".format(self.__class__.__name__, self.expression)
 
 
+class NegatedExpression(ExpressionWrapper):
+    """The logical negation of a conditional expression."""
+
+    def __init__(self, expression):
+        super().__init__(expression, output_field=fields.BooleanField())
+
+    def __invert__(self):
+        return self.expression.copy()
+
+    def as_sql(self, compiler, connection):
+        try:
+            sql, params = super().as_sql(compiler, connection)
+        except EmptyResultSet:
+            features = compiler.connection.features
+            if not features.supports_boolean_expr_in_select_clause:
+                return "1=1", ()
+            return compiler.compile(Value(True))
+        ops = compiler.connection.ops
+        # Some database backends (e.g. Oracle) don't allow EXISTS() and filters
+        # to be compared to another expression unless they're wrapped in a CASE
+        # WHEN.
+        if not ops.conditional_expression_supported_in_where_clause(self.expression):
+            return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params
+        return f"NOT {sql}", params
+
+    def resolve_expression(
+        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
+    ):
+        resolved = super().resolve_expression(
+            query, allow_joins, reuse, summarize, for_save
+        )
+        if not getattr(resolved.expression, "conditional", False):
+            raise TypeError("Cannot negate non-conditional expressions.")
+        return resolved
+
+    def select_format(self, compiler, sql, params):
+        # Wrap boolean expressions with a CASE WHEN expression if a database
+        # backend (e.g. Oracle) doesn't support boolean expression in SELECT or
+        # GROUP BY list.
+        expression_supported_in_where_clause = (
+            compiler.connection.ops.conditional_expression_supported_in_where_clause
+        )
+        if (
+            not compiler.connection.features.supports_boolean_expr_in_select_clause
+            # Avoid double wrapping.
+            and expression_supported_in_where_clause(self.expression)
+        ):
+            sql = "CASE WHEN {} THEN 1 ELSE 0 END".format(sql)
+        return sql, params
+
+
 @deconstructible(path="django.db.models.When")
 class When(Expression):
     template = "WHEN %(condition)s THEN %(result)s"
@@ -1486,34 +1543,10 @@ class Exists(Subquery):
     template = "EXISTS(%(subquery)s)"
     output_field = fields.BooleanField()
 
-    def __init__(self, queryset, negated=False, **kwargs):
-        self.negated = negated
+    def __init__(self, queryset, **kwargs):
         super().__init__(queryset, **kwargs)
         self.query = self.query.exists()
 
-    def __invert__(self):
-        clone = self.copy()
-        clone.negated = not self.negated
-        return clone
-
-    def as_sql(self, compiler, connection, **extra_context):
-        try:
-            sql, params = super().as_sql(
-                compiler,
-                connection,
-                **extra_context,
-            )
-        except EmptyResultSet:
-            if self.negated:
-                features = compiler.connection.features
-                if not features.supports_boolean_expr_in_select_clause:
-                    return "1=1", ()
-                return compiler.compile(Value(True))
-            raise
-        if self.negated:
-            sql = "NOT {}".format(sql)
-        return sql, params
-
     def select_format(self, compiler, sql, params):
         # Wrap EXISTS() with a CASE WHEN expression if a database backend
         # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP

+ 13 - 0
docs/ref/models/expressions.txt

@@ -255,6 +255,19 @@ is null) after companies that have been contacted::
     from django.db.models import F
     Company.objects.order_by(F('last_contacted').desc(nulls_last=True))
 
+Using ``F()`` with logical operations
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. versionadded:: 4.2
+
+``F()`` expressions that output ``BooleanField`` can be logically negated with
+the inversion operator ``~F()``. For example, to swap the activation status of
+companies::
+
+    from django.db.models import F
+
+    Company.objects.update(is_active=~F('is_active'))
+
 .. _func-expressions:
 
 ``Func()`` expressions

+ 6 - 0
docs/releases/4.2.txt

@@ -236,6 +236,9 @@ Models
 * :class:`~django.db.models.functions.Now` now supports microsecond precision
   on MySQL and millisecond precision on SQLite.
 
+* :class:`F() <django.db.models.F>` expressions that output ``BooleanField``
+  can now be negated using ``~F()`` (inversion operator).
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 
@@ -345,6 +348,9 @@ Miscellaneous
 * The minimum supported version of ``sqlparse`` is increased from 0.2.2 to
   0.2.3.
 
+* The undocumented ``negated`` parameter of the
+  :class:`~django.db.models.Exists` expression is removed.
+
 .. _deprecated-features-4.2:
 
 Features deprecated in 4.2

+ 56 - 0
tests/expressions/tests.py

@@ -48,6 +48,7 @@ from django.db.models.expressions import (
     Col,
     Combinable,
     CombinedExpression,
+    NegatedExpression,
     RawSQL,
     Ref,
 )
@@ -2536,6 +2537,61 @@ class ExpressionWrapperTests(SimpleTestCase):
         self.assertEqual(group_by_cols[0].output_field, expr.output_field)
 
 
+class NegatedExpressionTests(TestCase):
+    @classmethod
+    def setUpTestData(cls):
+        ceo = Employee.objects.create(firstname="Joe", lastname="Smith", salary=10)
+        cls.eu_company = Company.objects.create(
+            name="Example Inc.",
+            num_employees=2300,
+            num_chairs=5,
+            ceo=ceo,
+            based_in_eu=True,
+        )
+        cls.non_eu_company = Company.objects.create(
+            name="Foobar Ltd.",
+            num_employees=3,
+            num_chairs=4,
+            ceo=ceo,
+            based_in_eu=False,
+        )
+
+    def test_invert(self):
+        f = F("field")
+        self.assertEqual(~f, NegatedExpression(f))
+        self.assertIsNot(~~f, f)
+        self.assertEqual(~~f, f)
+
+    def test_filter(self):
+        self.assertSequenceEqual(
+            Company.objects.filter(~F("based_in_eu")),
+            [self.non_eu_company],
+        )
+
+        qs = Company.objects.annotate(eu_required=~Value(False))
+        self.assertSequenceEqual(
+            qs.filter(based_in_eu=F("eu_required")).order_by("eu_required"),
+            [self.eu_company],
+        )
+        self.assertSequenceEqual(
+            qs.filter(based_in_eu=~~F("eu_required")),
+            [self.eu_company],
+        )
+        self.assertSequenceEqual(
+            qs.filter(based_in_eu=~F("eu_required")),
+            [self.non_eu_company],
+        )
+        self.assertSequenceEqual(qs.filter(based_in_eu=~F("based_in_eu")), [])
+
+    def test_values(self):
+        self.assertSequenceEqual(
+            Company.objects.annotate(negated=~F("based_in_eu"))
+            .values_list("name", "negated")
+            .order_by("name"),
+            [("Example Inc.", False), ("Foobar Ltd.", True)],
+        )
+
+
 class OrderByTests(SimpleTestCase):
     def test_equal(self):
         self.assertEqual(

+ 2 - 2
tests/queries/test_q.py

@@ -8,7 +8,7 @@ from django.db.models import (
     Q,
     Value,
 )
-from django.db.models.expressions import RawSQL
+from django.db.models.expressions import NegatedExpression, RawSQL
 from django.db.models.functions import Lower
 from django.db.models.sql.where import NothingNode
 from django.test import SimpleTestCase, TestCase
@@ -87,7 +87,7 @@ class QTests(SimpleTestCase):
         ]
         for q in tests:
             with self.subTest(q=q):
-                self.assertIs(q.negated, True)
+                self.assertIsInstance(q, NegatedExpression)
 
     def test_deconstruct(self):
         q = Q(price__gt=F("discounted_price"))

+ 1 - 0
tests/update/models.py

@@ -10,6 +10,7 @@ class DataPoint(models.Model):
     name = models.CharField(max_length=20)
     value = models.CharField(max_length=20)
     another_value = models.CharField(max_length=20, blank=True)
+    is_active = models.BooleanField(default=True)
 
 
 class RelatedPoint(models.Model):

+ 28 - 2
tests/update/tests.py

@@ -2,7 +2,7 @@ import unittest
 
 from django.core.exceptions import FieldError
 from django.db import IntegrityError, connection, transaction
-from django.db.models import CharField, Count, F, IntegerField, Max
+from django.db.models import Case, CharField, Count, F, IntegerField, Max, When
 from django.db.models.functions import Abs, Concat, Lower
 from django.test import TestCase
 from django.test.utils import register_lookup
@@ -81,7 +81,7 @@ class AdvancedTests(TestCase):
     def setUpTestData(cls):
         cls.d0 = DataPoint.objects.create(name="d0", value="apple")
         cls.d2 = DataPoint.objects.create(name="d2", value="banana")
-        cls.d3 = DataPoint.objects.create(name="d3", value="banana")
+        cls.d3 = DataPoint.objects.create(name="d3", value="banana", is_active=False)
         cls.r1 = RelatedPoint.objects.create(name="r1", data=cls.d3)
 
     def test_update(self):
@@ -249,6 +249,32 @@ class AdvancedTests(TestCase):
         Bar.objects.annotate(abs_id=Abs("m2m_foo")).order_by("abs_id").update(x=3)
         self.assertEqual(Bar.objects.get().x, 3)
 
+    def test_update_negated_f(self):
+        DataPoint.objects.update(is_active=~F("is_active"))
+        self.assertCountEqual(
+            DataPoint.objects.values_list("name", "is_active"),
+            [("d0", False), ("d2", False), ("d3", True)],
+        )
+        DataPoint.objects.update(is_active=~F("is_active"))
+        self.assertCountEqual(
+            DataPoint.objects.values_list("name", "is_active"),
+            [("d0", True), ("d2", True), ("d3", False)],
+        )
+
+    def test_update_negated_f_conditional_annotation(self):
+        DataPoint.objects.annotate(
+            is_d2=Case(When(name="d2", then=True), default=False)
+        ).update(is_active=~F("is_d2"))
+        self.assertCountEqual(
+            DataPoint.objects.values_list("name", "is_active"),
+            [("d0", True), ("d2", False), ("d3", True)],
+        )
+
+    def test_updating_non_conditional_field(self):
+        msg = "Cannot negate non-conditional expressions."
+        with self.assertRaisesMessage(TypeError, msg):
+            DataPoint.objects.update(is_active=~F("name"))
+
 
 @unittest.skipUnless(
     connection.vendor == "mysql",