Преглед изворни кода

Fixed #27498 -- Fixed filtering on annotated DecimalField on SQLite.

Peter Inglesby пре 8 година
родитељ
комит
a4cac17200
3 измењених фајлова са 86 додато и 1 уклоњено
  1. 38 1
      django/db/models/lookups.py
  2. 10 0
      tests/lookup/models.py
  3. 38 0
      tests/lookup/test_decimalfield.py

+ 38 - 1
django/db/models/lookups.py

@@ -2,10 +2,13 @@ import itertools
 import math
 import warnings
 from copy import copy
+from decimal import Decimal
 
 from django.core.exceptions import EmptyResultSet
 from django.db.models.expressions import Func, Value
-from django.db.models.fields import DateTimeField, Field, IntegerField
+from django.db.models.fields import (
+    DateTimeField, DecimalField, Field, IntegerField,
+)
 from django.db.models.query_utils import RegisterLookupMixin
 from django.utils.deprecation import RemovedInDjango20Warning
 from django.utils.functional import cached_property
@@ -306,6 +309,40 @@ class IntegerLessThan(IntegerFieldFloatRounding, LessThan):
 IntegerField.register_lookup(IntegerLessThan)
 
 
+class DecimalComparisonLookup(object):
+    def as_sqlite(self, compiler, connection):
+        lhs_sql, params = self.process_lhs(compiler, connection)
+        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
+        params.extend(rhs_params)
+        # For comparisons whose lhs is a DecimalField, cast rhs AS NUMERIC
+        # because the rhs will have been converted to a string by the
+        # rev_typecast_decimal() adapter.
+        if isinstance(self.rhs, Decimal):
+            rhs_sql = 'CAST(%s AS NUMERIC)' % rhs_sql
+        rhs_sql = self.get_rhs_op(connection, rhs_sql)
+        return '%s %s' % (lhs_sql, rhs_sql), params
+
+
+@DecimalField.register_lookup
+class DecimalGreaterThan(DecimalComparisonLookup, GreaterThan):
+    pass
+
+
+@DecimalField.register_lookup
+class DecimalGreaterThanOrEqual(DecimalComparisonLookup, GreaterThanOrEqual):
+    pass
+
+
+@DecimalField.register_lookup
+class DecimalLessThan(DecimalComparisonLookup, LessThan):
+    pass
+
+
+@DecimalField.register_lookup
+class DecimalLessThanOrEqual(DecimalComparisonLookup, LessThanOrEqual):
+    pass
+
+
 class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
     lookup_name = 'in'
 

+ 10 - 0
tests/lookup/models.py

@@ -86,3 +86,13 @@ class MyISAMArticle(models.Model):
     class Meta:
         db_table = 'myisam_article'
         managed = False
+
+
+class Product(models.Model):
+    name = models.CharField(max_length=80)
+    qty_target = models.DecimalField(max_digits=6, decimal_places=2)
+
+
+class Stock(models.Model):
+    product = models.ForeignKey(Product, models.CASCADE)
+    qty_available = models.DecimalField(max_digits=6, decimal_places=2)

+ 38 - 0
tests/lookup/test_decimalfield.py

@@ -0,0 +1,38 @@
+from django.db.models.aggregates import Sum
+from django.db.models.expressions import F
+from django.test import TestCase
+
+from .models import Product, Stock
+
+
+class DecimalFieldLookupTests(TestCase):
+    @classmethod
+    def setUpTestData(cls):
+        cls.p1 = Product.objects.create(name='Product1', qty_target=10)
+        Stock.objects.create(product=cls.p1, qty_available=5)
+        Stock.objects.create(product=cls.p1, qty_available=6)
+        cls.p2 = Product.objects.create(name='Product2', qty_target=10)
+        Stock.objects.create(product=cls.p2, qty_available=5)
+        Stock.objects.create(product=cls.p2, qty_available=5)
+        cls.p3 = Product.objects.create(name='Product3', qty_target=10)
+        Stock.objects.create(product=cls.p3, qty_available=5)
+        Stock.objects.create(product=cls.p3, qty_available=4)
+        cls.queryset = Product.objects.annotate(
+            qty_available_sum=Sum('stock__qty_available'),
+        ).annotate(qty_needed=F('qty_target') - F('qty_available_sum'))
+
+    def test_gt(self):
+        qs = self.queryset.filter(qty_needed__gt=0)
+        self.assertCountEqual(qs, [self.p3])
+
+    def test_gte(self):
+        qs = self.queryset.filter(qty_needed__gte=0)
+        self.assertCountEqual(qs, [self.p2, self.p3])
+
+    def test_lt(self):
+        qs = self.queryset.filter(qty_needed__lt=0)
+        self.assertCountEqual(qs, [self.p1])
+
+    def test_lte(self):
+        qs = self.queryset.filter(qty_needed__lte=0)
+        self.assertCountEqual(qs, [self.p1, self.p2])