Browse Source

Fixed #33257 -- Fixed Case() and ExpressionWrapper() with decimal values on SQLite.

Matthijs Kooijman 3 years ago
parent
commit
1a5023883b
3 changed files with 18 additions and 2 deletions
  1. 2 2
      django/db/models/expressions.py
  2. 7 0
      tests/expressions/tests.py
  3. 9 0
      tests/expressions_case/tests.py

+ 2 - 2
django/db/models/expressions.py

@@ -933,7 +933,7 @@ class ExpressionList(Func):
         return self.as_sql(compiler, connection, **extra_context)
 
 
-class ExpressionWrapper(Expression):
+class ExpressionWrapper(SQLiteNumericMixin, Expression):
     """
     An expression that can wrap another expression so that it can provide
     extra context to the inner expression, such as the output_field.
@@ -1032,7 +1032,7 @@ class When(Expression):
         return cols
 
 
-class Case(Expression):
+class Case(SQLiteNumericMixin, Expression):
     """
     An SQL searched CASE expression:
 

+ 7 - 0
tests/expressions/tests.py

@@ -1178,6 +1178,13 @@ class ExpressionsNumericTests(TestCase):
             ordered=False
         )
 
+    def test_filter_decimal_expression(self):
+        obj = Number.objects.create(integer=0, float=1, decimal_value=Decimal('1'))
+        qs = Number.objects.annotate(
+            x=ExpressionWrapper(Value(1), output_field=DecimalField()),
+        ).filter(Q(x=1, integer=0) & Q(x=Decimal('1')))
+        self.assertSequenceEqual(qs, [obj])
+
     def test_complex_expressions(self):
         """
         Complex expressions of different connection types are possible.

+ 9 - 0
tests/expressions_case/tests.py

@@ -256,6 +256,15 @@ class CaseExpressionTests(TestCase):
             transform=attrgetter('integer', 'test')
         )
 
+    def test_annotate_filter_decimal(self):
+        obj = CaseTestModel.objects.create(integer=0, decimal=Decimal('1'))
+        qs = CaseTestModel.objects.annotate(
+            x=Case(When(integer=0, then=F('decimal'))),
+            y=Case(When(integer=0, then=Value(Decimal('1')))),
+        )
+        self.assertSequenceEqual(qs.filter(Q(x=1) & Q(x=Decimal('1'))), [obj])
+        self.assertSequenceEqual(qs.filter(Q(y=1) & Q(y=Decimal('1'))), [obj])
+
     def test_annotate_values_not_in_order_by(self):
         self.assertEqual(
             list(CaseTestModel.objects.annotate(test=Case(