浏览代码

Fixed #31487 -- Added precision argument to Round().

Nick Pope 4 年之前
父节点
当前提交
2f13c476ab

+ 6 - 0
django/db/backends/sqlite3/features.py

@@ -65,6 +65,12 @@ class DatabaseFeatures(BaseDatabaseFeatures):
             "SQLite doesn't have a constraint.": {
             "SQLite doesn't have a constraint.": {
                 'model_fields.test_integerfield.PositiveIntegerFieldTests.test_negative_values',
                 'model_fields.test_integerfield.PositiveIntegerFieldTests.test_negative_values',
             },
             },
+            "SQLite doesn't support negative precision for ROUND().": {
+                'db_functions.math.test_round.RoundTests.test_null_with_negative_precision',
+                'db_functions.math.test_round.RoundTests.test_decimal_with_negative_precision',
+                'db_functions.math.test_round.RoundTests.test_float_with_negative_precision',
+                'db_functions.math.test_round.RoundTests.test_integer_with_negative_precision',
+            },
         }
         }
         if Database.sqlite_version_info < (3, 27):
         if Database.sqlite_version_info < (3, 27):
             skips.update({
             skips.update({

+ 16 - 2
django/db/models/functions/math.py

@@ -1,6 +1,6 @@
 import math
 import math
 
 
-from django.db.models.expressions import Func
+from django.db.models.expressions import Func, Value
 from django.db.models.fields import FloatField, IntegerField
 from django.db.models.fields import FloatField, IntegerField
 from django.db.models.functions import Cast
 from django.db.models.functions import Cast
 from django.db.models.functions.mixins import (
 from django.db.models.functions.mixins import (
@@ -158,9 +158,23 @@ class Random(NumericOutputFieldMixin, Func):
         return []
         return []
 
 
 
 
-class Round(Transform):
+class Round(FixDecimalInputMixin, Transform):
     function = 'ROUND'
     function = 'ROUND'
     lookup_name = 'round'
     lookup_name = 'round'
+    arity = None  # Override Transform's arity=1 to enable passing precision.
+
+    def __init__(self, expression, precision=0, **extra):
+        super().__init__(expression, precision, **extra)
+
+    def as_sqlite(self, compiler, connection, **extra_context):
+        precision = self.get_source_expressions()[1]
+        if isinstance(precision, Value) and precision.value < 0:
+            raise ValueError('SQLite does not support negative precision.')
+        return super().as_sqlite(compiler, connection, **extra_context)
+
+    def _resolve_output_field(self):
+        source = self.get_source_expressions()[0]
+        return source.output_field
 
 
 
 
 class Sign(Transform):
 class Sign(Transform):

+ 10 - 5
docs/ref/models/database-functions.txt

@@ -1147,18 +1147,19 @@ Returns a random value in the range ``0.0 ≤ x < 1.0``.
 ``Round``
 ``Round``
 ---------
 ---------
 
 
-.. class:: Round(expression, **extra)
+.. class:: Round(expression, precision=0, **extra)
 
 
-Rounds a numeric field or expression to the nearest integer. Whether half
+Rounds a numeric field or expression to ``precision`` (must be an integer)
+decimal places. By default, it rounds to the nearest integer. Whether half
 values are rounded up or down depends on the database.
 values are rounded up or down depends on the database.
 
 
 Usage example::
 Usage example::
 
 
     >>> from django.db.models.functions import Round
     >>> from django.db.models.functions import Round
-    >>> Vector.objects.create(x=5.4, y=-2.3)
-    >>> vector = Vector.objects.annotate(x_r=Round('x'), y_r=Round('y')).get()
+    >>> Vector.objects.create(x=5.4, y=-2.37)
+    >>> vector = Vector.objects.annotate(x_r=Round('x'), y_r=Round('y', precision=1)).get()
     >>> vector.x_r, vector.y_r
     >>> vector.x_r, vector.y_r
-    (5.0, -2.0)
+    (5.0, -2.4)
 
 
 It can also be registered as a transform. For example::
 It can also be registered as a transform. For example::
 
 
@@ -1168,6 +1169,10 @@ It can also be registered as a transform. For example::
     >>> # Get vectors whose round() is less than 20
     >>> # Get vectors whose round() is less than 20
     >>> vectors = Vector.objects.filter(x__round__lt=20, y__round__lt=20)
     >>> vectors = Vector.objects.filter(x__round__lt=20, y__round__lt=20)
 
 
+.. versionchanged:: 4.0
+
+    The ``precision`` argument was added.
+
 ``Sign``
 ``Sign``
 --------
 --------
 
 

+ 4 - 0
docs/releases/4.0.txt

@@ -222,6 +222,10 @@ Models
   whether the queryset contains the given object. This tries to perform the
   whether the queryset contains the given object. This tries to perform the
   query in the simplest and fastest way possible.
   query in the simplest and fastest way possible.
 
 
+* The new ``precision`` argument of the
+  :class:`Round() <django.db.models.functions.Round>` database function allows
+  specifying the number of decimal places after rounding.
+
 Requests and Responses
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 ~~~~~~~~~~~~~~~~~~~~~~
 
 

+ 77 - 1
tests/db_functions/math/test_round.py

@@ -1,7 +1,9 @@
+import unittest
 from decimal import Decimal
 from decimal import Decimal
 
 
+from django.db import connection
 from django.db.models import DecimalField
 from django.db.models import DecimalField
-from django.db.models.functions import Round
+from django.db.models.functions import Pi, Round
 from django.test import TestCase
 from django.test import TestCase
 from django.test.utils import register_lookup
 from django.test.utils import register_lookup
 
 
@@ -15,6 +17,16 @@ class RoundTests(TestCase):
         obj = IntegerModel.objects.annotate(null_round=Round('normal')).first()
         obj = IntegerModel.objects.annotate(null_round=Round('normal')).first()
         self.assertIsNone(obj.null_round)
         self.assertIsNone(obj.null_round)
 
 
+    def test_null_with_precision(self):
+        IntegerModel.objects.create()
+        obj = IntegerModel.objects.annotate(null_round=Round('normal', 5)).first()
+        self.assertIsNone(obj.null_round)
+
+    def test_null_with_negative_precision(self):
+        IntegerModel.objects.create()
+        obj = IntegerModel.objects.annotate(null_round=Round('normal', -1)).first()
+        self.assertIsNone(obj.null_round)
+
     def test_decimal(self):
     def test_decimal(self):
         DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6'))
         DecimalModel.objects.create(n1=Decimal('-12.9'), n2=Decimal('0.6'))
         obj = DecimalModel.objects.annotate(n1_round=Round('n1'), n2_round=Round('n2')).first()
         obj = DecimalModel.objects.annotate(n1_round=Round('n1'), n2_round=Round('n2')).first()
@@ -23,6 +35,23 @@ class RoundTests(TestCase):
         self.assertAlmostEqual(obj.n1_round, obj.n1, places=0)
         self.assertAlmostEqual(obj.n1_round, obj.n1, places=0)
         self.assertAlmostEqual(obj.n2_round, obj.n2, places=0)
         self.assertAlmostEqual(obj.n2_round, obj.n2, places=0)
 
 
+    def test_decimal_with_precision(self):
+        DecimalModel.objects.create(n1=Decimal('-5.75'), n2=Pi())
+        obj = DecimalModel.objects.annotate(
+            n1_round=Round('n1', 1),
+            n2_round=Round('n2', 5),
+        ).first()
+        self.assertIsInstance(obj.n1_round, Decimal)
+        self.assertIsInstance(obj.n2_round, Decimal)
+        self.assertAlmostEqual(obj.n1_round, obj.n1, places=1)
+        self.assertAlmostEqual(obj.n2_round, obj.n2, places=5)
+
+    def test_decimal_with_negative_precision(self):
+        DecimalModel.objects.create(n1=Decimal('365.25'))
+        obj = DecimalModel.objects.annotate(n1_round=Round('n1', -1)).first()
+        self.assertIsInstance(obj.n1_round, Decimal)
+        self.assertEqual(obj.n1_round, 370)
+
     def test_float(self):
     def test_float(self):
         FloatModel.objects.create(f1=-27.55, f2=0.55)
         FloatModel.objects.create(f1=-27.55, f2=0.55)
         obj = FloatModel.objects.annotate(f1_round=Round('f1'), f2_round=Round('f2')).first()
         obj = FloatModel.objects.annotate(f1_round=Round('f1'), f2_round=Round('f2')).first()
@@ -31,6 +60,23 @@ class RoundTests(TestCase):
         self.assertAlmostEqual(obj.f1_round, obj.f1, places=0)
         self.assertAlmostEqual(obj.f1_round, obj.f1, places=0)
         self.assertAlmostEqual(obj.f2_round, obj.f2, places=0)
         self.assertAlmostEqual(obj.f2_round, obj.f2, places=0)
 
 
+    def test_float_with_precision(self):
+        FloatModel.objects.create(f1=-5.75, f2=Pi())
+        obj = FloatModel.objects.annotate(
+            f1_round=Round('f1', 1),
+            f2_round=Round('f2', 5),
+        ).first()
+        self.assertIsInstance(obj.f1_round, float)
+        self.assertIsInstance(obj.f2_round, float)
+        self.assertAlmostEqual(obj.f1_round, obj.f1, places=1)
+        self.assertAlmostEqual(obj.f2_round, obj.f2, places=5)
+
+    def test_float_with_negative_precision(self):
+        FloatModel.objects.create(f1=365.25)
+        obj = FloatModel.objects.annotate(f1_round=Round('f1', -1)).first()
+        self.assertIsInstance(obj.f1_round, float)
+        self.assertEqual(obj.f1_round, 370)
+
     def test_integer(self):
     def test_integer(self):
         IntegerModel.objects.create(small=-20, normal=15, big=-1)
         IntegerModel.objects.create(small=-20, normal=15, big=-1)
         obj = IntegerModel.objects.annotate(
         obj = IntegerModel.objects.annotate(
@@ -45,9 +91,39 @@ class RoundTests(TestCase):
         self.assertAlmostEqual(obj.normal_round, obj.normal, places=0)
         self.assertAlmostEqual(obj.normal_round, obj.normal, places=0)
         self.assertAlmostEqual(obj.big_round, obj.big, places=0)
         self.assertAlmostEqual(obj.big_round, obj.big, places=0)
 
 
+    def test_integer_with_precision(self):
+        IntegerModel.objects.create(small=-5, normal=3, big=-100)
+        obj = IntegerModel.objects.annotate(
+            small_round=Round('small', 1),
+            normal_round=Round('normal', 5),
+            big_round=Round('big', 2),
+        ).first()
+        self.assertIsInstance(obj.small_round, int)
+        self.assertIsInstance(obj.normal_round, int)
+        self.assertIsInstance(obj.big_round, int)
+        self.assertAlmostEqual(obj.small_round, obj.small, places=1)
+        self.assertAlmostEqual(obj.normal_round, obj.normal, places=5)
+        self.assertAlmostEqual(obj.big_round, obj.big, places=2)
+
+    def test_integer_with_negative_precision(self):
+        IntegerModel.objects.create(normal=365)
+        obj = IntegerModel.objects.annotate(normal_round=Round('normal', -1)).first()
+        self.assertIsInstance(obj.normal_round, int)
+        self.assertEqual(obj.normal_round, 370)
+
     def test_transform(self):
     def test_transform(self):
         with register_lookup(DecimalField, Round):
         with register_lookup(DecimalField, Round):
             DecimalModel.objects.create(n1=Decimal('2.0'), n2=Decimal('0'))
             DecimalModel.objects.create(n1=Decimal('2.0'), n2=Decimal('0'))
             DecimalModel.objects.create(n1=Decimal('-1.0'), n2=Decimal('0'))
             DecimalModel.objects.create(n1=Decimal('-1.0'), n2=Decimal('0'))
             obj = DecimalModel.objects.filter(n1__round__gt=0).get()
             obj = DecimalModel.objects.filter(n1__round__gt=0).get()
             self.assertEqual(obj.n1, Decimal('2.0'))
             self.assertEqual(obj.n1, Decimal('2.0'))
+
+    @unittest.skipUnless(
+        connection.vendor == 'sqlite',
+        "SQLite doesn't support negative precision.",
+    )
+    def test_unsupported_negative_precision(self):
+        FloatModel.objects.create(f1=123.45)
+        msg = 'SQLite does not support negative precision.'
+        with self.assertRaisesMessage(ValueError, msg):
+            FloatModel.objects.annotate(value=Round('f1', -1)).first()

+ 1 - 1
tests/db_functions/migrations/0002_create_test_models.py

@@ -56,7 +56,7 @@ class Migration(migrations.Migration):
             name='DecimalModel',
             name='DecimalModel',
             fields=[
             fields=[
                 ('n1', models.DecimalField(decimal_places=2, max_digits=6)),
                 ('n1', models.DecimalField(decimal_places=2, max_digits=6)),
-                ('n2', models.DecimalField(decimal_places=2, max_digits=6)),
+                ('n2', models.DecimalField(decimal_places=7, max_digits=9, null=True, blank=True)),
             ],
             ],
         ),
         ),
         migrations.CreateModel(
         migrations.CreateModel(

+ 1 - 1
tests/db_functions/models.py

@@ -42,7 +42,7 @@ class DTModel(models.Model):
 
 
 class DecimalModel(models.Model):
 class DecimalModel(models.Model):
     n1 = models.DecimalField(decimal_places=2, max_digits=6)
     n1 = models.DecimalField(decimal_places=2, max_digits=6)
-    n2 = models.DecimalField(decimal_places=2, max_digits=6)
+    n2 = models.DecimalField(decimal_places=7, max_digits=9, null=True, blank=True)
 
 
 
 
 class IntegerModel(models.Model):
 class IntegerModel(models.Model):