2
0
Эх сурвалжийг харах

Fixed #24636 -- Added model field validation for decimal places and max digits.

Iulia Chiriac 10 жил өмнө
parent
commit
75ed590032

+ 69 - 0
django/core/validators.py

@@ -346,3 +346,72 @@ class MaxLengthValidator(BaseValidator):
         'Ensure this value has at most %(limit_value)d characters (it has %(show_value)d).',
         'limit_value')
     code = 'max_length'
+
+
+@deconstructible
+class DecimalValidator(object):
+    """
+    Validate that the input does not exceed the maximum number of digits
+    expected, otherwise raise ValidationError.
+    """
+    messages = {
+        'max_digits': ungettext_lazy(
+            'Ensure that there are no more than %(max)s digit in total.',
+            'Ensure that there are no more than %(max)s digits in total.',
+            'max'
+        ),
+        'max_decimal_places': ungettext_lazy(
+            'Ensure that there are no more than %(max)s decimal place.',
+            'Ensure that there are no more than %(max)s decimal places.',
+            'max'
+        ),
+        'max_whole_digits': ungettext_lazy(
+            'Ensure that there are no more than %(max)s digit before the decimal point.',
+            'Ensure that there are no more than %(max)s digits before the decimal point.',
+            'max'
+        ),
+    }
+
+    def __init__(self, max_digits, decimal_places):
+        self.max_digits = max_digits
+        self.decimal_places = decimal_places
+
+    def __call__(self, value):
+        digit_tuple, exponent = value.as_tuple()[1:]
+        decimals = abs(exponent)
+        # digit_tuple doesn't include any leading zeros.
+        digits = len(digit_tuple)
+        if decimals > digits:
+            # We have leading zeros up to or past the decimal point. Count
+            # everything past the decimal point as a digit. We do not count
+            # 0 before the decimal point as a digit since that would mean
+            # we would not allow max_digits = decimal_places.
+            digits = decimals
+        whole_digits = digits - decimals
+
+        if self.max_digits is not None and digits > self.max_digits:
+            raise ValidationError(
+                self.messages['max_digits'],
+                code='max_digits',
+                params={'max': self.max_digits},
+            )
+        if self.decimal_places is not None and decimals > self.decimal_places:
+            raise ValidationError(
+                self.messages['max_decimal_places'],
+                code='max_decimal_places',
+                params={'max': self.decimal_places},
+            )
+        if (self.max_digits is not None and self.decimal_places is not None
+                and whole_digits > (self.max_digits - self.decimal_places)):
+            raise ValidationError(
+                self.messages['max_whole_digits'],
+                code='max_whole_digits',
+                params={'max': (self.max_digits - self.decimal_places)},
+            )
+
+    def __eq__(self, other):
+        return (
+            isinstance(other, self.__class__) and
+            self.max_digits == other.max_digits and
+            self.decimal_places == other.decimal_places
+        )

+ 6 - 0
django/db/models/fields/__init__.py

@@ -1578,6 +1578,12 @@ class DecimalField(Field):
             ]
         return []
 
+    @cached_property
+    def validators(self):
+        return super(DecimalField, self).validators + [
+            validators.DecimalValidator(self.max_digits, self.decimal_places)
+        ]
+
     def deconstruct(self):
         name, path, args, kwargs = super(DecimalField, self).deconstruct()
         if self.max_digits is not None:

+ 1 - 44
django/forms/fields.py

@@ -334,23 +334,12 @@ class FloatField(IntegerField):
 class DecimalField(IntegerField):
     default_error_messages = {
         'invalid': _('Enter a number.'),
-        'max_digits': ungettext_lazy(
-            'Ensure that there are no more than %(max)s digit in total.',
-            'Ensure that there are no more than %(max)s digits in total.',
-            'max'),
-        'max_decimal_places': ungettext_lazy(
-            'Ensure that there are no more than %(max)s decimal place.',
-            'Ensure that there are no more than %(max)s decimal places.',
-            'max'),
-        'max_whole_digits': ungettext_lazy(
-            'Ensure that there are no more than %(max)s digit before the decimal point.',
-            'Ensure that there are no more than %(max)s digits before the decimal point.',
-            'max'),
     }
 
     def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs):
         self.max_digits, self.decimal_places = max_digits, decimal_places
         super(DecimalField, self).__init__(max_value, min_value, *args, **kwargs)
+        self.validators.append(validators.DecimalValidator(max_digits, decimal_places))
 
     def to_python(self, value):
         """
@@ -379,38 +368,6 @@ class DecimalField(IntegerField):
         # isn't equal to itself, so we can use this to identify NaN
         if value != value or value == Decimal("Inf") or value == Decimal("-Inf"):
             raise ValidationError(self.error_messages['invalid'], code='invalid')
-        sign, digittuple, exponent = value.as_tuple()
-        decimals = abs(exponent)
-        # digittuple doesn't include any leading zeros.
-        digits = len(digittuple)
-        if decimals > digits:
-            # We have leading zeros up to or past the decimal point.  Count
-            # everything past the decimal point as a digit.  We do not count
-            # 0 before the decimal point as a digit since that would mean
-            # we would not allow max_digits = decimal_places.
-            digits = decimals
-        whole_digits = digits - decimals
-
-        if self.max_digits is not None and digits > self.max_digits:
-            raise ValidationError(
-                self.error_messages['max_digits'],
-                code='max_digits',
-                params={'max': self.max_digits},
-            )
-        if self.decimal_places is not None and decimals > self.decimal_places:
-            raise ValidationError(
-                self.error_messages['max_decimal_places'],
-                code='max_decimal_places',
-                params={'max': self.decimal_places},
-            )
-        if (self.max_digits is not None and self.decimal_places is not None
-                and whole_digits > (self.max_digits - self.decimal_places)):
-            raise ValidationError(
-                self.error_messages['max_whole_digits'],
-                code='max_whole_digits',
-                params={'max': (self.max_digits - self.decimal_places)},
-            )
-        return value
 
     def widget_attrs(self, widget):
         attrs = super(DecimalField, self).widget_attrs(widget)

+ 16 - 0
docs/ref/validators.txt

@@ -281,3 +281,19 @@ to, or in lieu of custom ``field.clean()`` methods.
     .. versionchanged:: 1.8
 
        The ``message`` parameter was added.
+
+``DecimalValidator``
+--------------------
+
+.. class:: DecimalValidator(max_digits, decimal_places)
+
+    .. versionadded:: 1.9
+
+    Raises :exc:`~django.core.exceptions.ValidationError` with the following
+    codes:
+
+    - ``'max_digits'`` if the number of digits is larger than ``max_digits``.
+    - ``'max_decimal_places'`` if the number of decimals is larger than
+      ``decimal_places``.
+    - ``'max_whole_digits'`` if the number of whole digits is larger than
+      the difference between ``max_digits`` and ``decimal_places``.

+ 18 - 0
tests/model_fields/tests.py

@@ -165,6 +165,24 @@ class DecimalFieldTests(test.TestCase):
         # This should not crash. That counts as a win for our purposes.
         Foo.objects.filter(d__gte=100000000000)
 
+    def test_max_digits_validation(self):
+        field = models.DecimalField(max_digits=2)
+        expected_message = validators.DecimalValidator.messages['max_digits'] % {'max': 2}
+        with self.assertRaisesMessage(ValidationError, expected_message):
+            field.clean(100, None)
+
+    def test_max_decimal_places_validation(self):
+        field = models.DecimalField(decimal_places=1)
+        expected_message = validators.DecimalValidator.messages['max_decimal_places'] % {'max': 1}
+        with self.assertRaisesMessage(ValidationError, expected_message):
+            field.clean(Decimal('0.99'), None)
+
+    def test_max_whole_digits_validation(self):
+        field = models.DecimalField(max_digits=3, decimal_places=1)
+        expected_message = validators.DecimalValidator.messages['max_whole_digits'] % {'max': 2}
+        with self.assertRaisesMessage(ValidationError, expected_message):
+            field.clean(Decimal('999'), None)
+
 
 class ForeignKeyTests(test.TestCase):
     def test_callable_default(self):

+ 24 - 5
tests/validators/tests.py

@@ -10,11 +10,12 @@ from unittest import TestCase
 
 from django.core.exceptions import ValidationError
 from django.core.validators import (
-    BaseValidator, EmailValidator, MaxLengthValidator, MaxValueValidator,
-    MinLengthValidator, MinValueValidator, RegexValidator, URLValidator,
-    int_list_validator, validate_comma_separated_integer_list, validate_email,
-    validate_integer, validate_ipv4_address, validate_ipv6_address,
-    validate_ipv46_address, validate_slug, validate_unicode_slug,
+    BaseValidator, DecimalValidator, EmailValidator, MaxLengthValidator,
+    MaxValueValidator, MinLengthValidator, MinValueValidator, RegexValidator,
+    URLValidator, int_list_validator, validate_comma_separated_integer_list,
+    validate_email, validate_integer, validate_ipv4_address,
+    validate_ipv6_address, validate_ipv46_address, validate_slug,
+    validate_unicode_slug,
 )
 from django.test import SimpleTestCase
 from django.test.utils import str_prefix
@@ -401,3 +402,21 @@ class TestValidatorEquality(TestCase):
             MinValueValidator(45),
             MinValueValidator(11),
         )
+
+    def test_decimal_equality(self):
+        self.assertEqual(
+            DecimalValidator(1, 2),
+            DecimalValidator(1, 2),
+        )
+        self.assertNotEqual(
+            DecimalValidator(1, 2),
+            DecimalValidator(1, 1),
+        )
+        self.assertNotEqual(
+            DecimalValidator(1, 2),
+            DecimalValidator(2, 2),
+        )
+        self.assertNotEqual(
+            DecimalValidator(1, 2),
+            MinValueValidator(11),
+        )