Browse Source

Fixed #31606 -- Allowed using condition with lookups in When() expression.

Ryan Heard 4 years ago
parent
commit
587b179d41

+ 5 - 2
django/db/models/expressions.py

@@ -876,8 +876,11 @@ class When(Expression):
     conditional = False
     conditional = False
 
 
     def __init__(self, condition=None, then=None, **lookups):
     def __init__(self, condition=None, then=None, **lookups):
-        if lookups and condition is None:
-            condition, lookups = Q(**lookups), None
+        if lookups:
+            if condition is None:
+                condition, lookups = Q(**lookups), None
+            elif getattr(condition, 'conditional', False):
+                condition, lookups = Q(condition, **lookups), None
         if condition is None or not getattr(condition, 'conditional', False) or lookups:
         if condition is None or not getattr(condition, 'conditional', False) or lookups:
             raise TypeError(
             raise TypeError(
                 'When() supports a Q object, a boolean expression, or lookups '
                 'When() supports a Q object, a boolean expression, or lookups '

+ 4 - 0
docs/ref/models/conditional-expressions.txt

@@ -81,6 +81,10 @@ Keep in mind that each of these values can be an expression.
         >>> When(then__exact=0, then=1)
         >>> When(then__exact=0, then=1)
         >>> When(Q(then=0), then=1)
         >>> When(Q(then=0), then=1)
 
 
+.. versionchanged:: 3.2
+
+    Support for using the ``condition`` argument with ``lookups`` was added.
+
 ``Case``
 ``Case``
 --------
 --------
 
 

+ 3 - 0
docs/releases/3.2.txt

@@ -178,6 +178,9 @@ Models
   supported on PostgreSQL, allows acquiring weaker locks that don't block the
   supported on PostgreSQL, allows acquiring weaker locks that don't block the
   creation of rows that reference locked rows through a foreign key.
   creation of rows that reference locked rows through a foreign key.
 
 
+* :class:`When() <django.db.models.expressions.When>` expression now allows
+  using the ``condition`` argument with ``lookups``.
+
 Requests and Responses
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 ~~~~~~~~~~~~~~~~~~~~~~
 
 

+ 14 - 1
tests/expressions_case/tests.py

@@ -6,7 +6,7 @@ from uuid import UUID
 
 
 from django.core.exceptions import FieldError
 from django.core.exceptions import FieldError
 from django.db.models import (
 from django.db.models import (
-    BinaryField, Case, CharField, Count, DurationField, F,
+    BinaryField, BooleanField, Case, CharField, Count, DurationField, F,
     GenericIPAddressField, IntegerField, Max, Min, Q, Sum, TextField,
     GenericIPAddressField, IntegerField, Max, Min, Q, Sum, TextField,
     TimeField, UUIDField, Value, When,
     TimeField, UUIDField, Value, When,
 )
 )
@@ -312,6 +312,17 @@ class CaseExpressionTests(TestCase):
             transform=attrgetter('integer', 'integer2')
             transform=attrgetter('integer', 'integer2')
         )
         )
 
 
+    def test_condition_with_lookups(self):
+        qs = CaseTestModel.objects.annotate(
+            test=Case(
+                When(Q(integer2=1), string='2', then=Value(False)),
+                When(Q(integer2=1), string='1', then=Value(True)),
+                default=Value(False),
+                output_field=BooleanField(),
+            ),
+        )
+        self.assertIs(qs.get(integer=1).test, True)
+
     def test_case_reuse(self):
     def test_case_reuse(self):
         SOME_CASE = Case(
         SOME_CASE = Case(
             When(pk=0, then=Value('0')),
             When(pk=0, then=Value('0')),
@@ -1350,6 +1361,8 @@ class CaseWhenTests(SimpleTestCase):
             When(condition=object())
             When(condition=object())
         with self.assertRaisesMessage(TypeError, msg):
         with self.assertRaisesMessage(TypeError, msg):
             When(condition=Value(1, output_field=IntegerField()))
             When(condition=Value(1, output_field=IntegerField()))
+        with self.assertRaisesMessage(TypeError, msg):
+            When(Value(1, output_field=IntegerField()), string='1')
         with self.assertRaisesMessage(TypeError, msg):
         with self.assertRaisesMessage(TypeError, msg):
             When()
             When()