Răsfoiți Sursa

Fixed #27147 -- Allowed specifying bounds of tuple inputs for non-discrete range fields.

Guilherme Martins Crocetti 3 ani în urmă
părinte
comite
fc565cb539

+ 37 - 2
django/contrib/postgres/fields/ranges.py

@@ -44,6 +44,10 @@ class RangeField(models.Field):
     empty_strings_allowed = False
 
     def __init__(self, *args, **kwargs):
+        if 'default_bounds' in kwargs:
+            raise TypeError(
+                f"Cannot use 'default_bounds' with {self.__class__.__name__}."
+            )
         # Initializing base_field here ensures that its model matches the model for self.
         if hasattr(self, 'base_field'):
             self.base_field = self.base_field()
@@ -112,6 +116,37 @@ class RangeField(models.Field):
         return super().formfield(**kwargs)
 
 
+CANONICAL_RANGE_BOUNDS = '[)'
+
+
+class ContinuousRangeField(RangeField):
+    """
+    Continuous range field. It allows specifying default bounds for list and
+    tuple inputs.
+    """
+
+    def __init__(self, *args, default_bounds=CANONICAL_RANGE_BOUNDS, **kwargs):
+        if default_bounds not in ('[)', '(]', '()', '[]'):
+            raise ValueError("default_bounds must be one of '[)', '(]', '()', or '[]'.")
+        self.default_bounds = default_bounds
+        super().__init__(*args, **kwargs)
+
+    def get_prep_value(self, value):
+        if isinstance(value, (list, tuple)):
+            return self.range_type(value[0], value[1], self.default_bounds)
+        return super().get_prep_value(value)
+
+    def formfield(self, **kwargs):
+        kwargs.setdefault('default_bounds', self.default_bounds)
+        return super().formfield(**kwargs)
+
+    def deconstruct(self):
+        name, path, args, kwargs = super().deconstruct()
+        if self.default_bounds and self.default_bounds != CANONICAL_RANGE_BOUNDS:
+            kwargs['default_bounds'] = self.default_bounds
+        return name, path, args, kwargs
+
+
 class IntegerRangeField(RangeField):
     base_field = models.IntegerField
     range_type = NumericRange
@@ -130,7 +165,7 @@ class BigIntegerRangeField(RangeField):
         return 'int8range'
 
 
-class DecimalRangeField(RangeField):
+class DecimalRangeField(ContinuousRangeField):
     base_field = models.DecimalField
     range_type = NumericRange
     form_field = forms.DecimalRangeField
@@ -139,7 +174,7 @@ class DecimalRangeField(RangeField):
         return 'numrange'
 
 
-class DateTimeRangeField(RangeField):
+class DateTimeRangeField(ContinuousRangeField):
     base_field = models.DateTimeField
     range_type = DateTimeTZRange
     form_field = forms.DateTimeRangeField

+ 4 - 1
django/contrib/postgres/forms/ranges.py

@@ -42,6 +42,9 @@ class BaseRangeField(forms.MultiValueField):
             kwargs['fields'] = [self.base_field(required=False), self.base_field(required=False)]
         kwargs.setdefault('required', False)
         kwargs.setdefault('require_all_fields', False)
+        self.range_kwargs = {}
+        if default_bounds := kwargs.pop('default_bounds', None):
+            self.range_kwargs = {'bounds': default_bounds}
         super().__init__(**kwargs)
 
     def prepare_value(self, value):
@@ -68,7 +71,7 @@ class BaseRangeField(forms.MultiValueField):
                 code='bound_ordering',
             )
         try:
-            range_value = self.range_type(lower, upper)
+            range_value = self.range_type(lower, upper, **self.range_kwargs)
         except TypeError:
             raise exceptions.ValidationError(
                 self.error_messages['invalid'],

+ 27 - 5
docs/ref/contrib/postgres/fields.txt

@@ -503,9 +503,9 @@ All of the range fields translate to :ref:`psycopg2 Range objects
 <psycopg2:adapt-range>` in Python, but also accept tuples as input if no bounds
 information is necessary. The default is lower bound included, upper bound
 excluded, that is ``[)`` (see the PostgreSQL documentation for details about
-`different bounds`_).
-
-.. _different bounds: https://www.postgresql.org/docs/current/rangetypes.html#RANGETYPES-IO
+`different bounds`_). The default bounds can be changed for non-discrete range
+fields (:class:`.DateTimeRangeField` and :class:`.DecimalRangeField`) by using
+the ``default_bounds`` argument.
 
 ``IntegerRangeField``
 ---------------------
@@ -538,23 +538,43 @@ excluded, that is ``[)`` (see the PostgreSQL documentation for details about
 ``DecimalRangeField``
 ---------------------
 
-.. class:: DecimalRangeField(**options)
+.. class:: DecimalRangeField(default_bounds='[)', **options)
 
     Stores a range of floating point values. Based on a
     :class:`~django.db.models.DecimalField`. Represented by a ``numrange`` in
     the database and a :class:`~psycopg2:psycopg2.extras.NumericRange` in
     Python.
 
+    .. attribute:: DecimalRangeField.default_bounds
+
+        .. versionadded:: 4.1
+
+        Optional. The value of ``bounds`` for list and tuple inputs. The
+        default is lower bound included, upper bound excluded, that is ``[)``
+        (see the PostgreSQL documentation for details about
+        `different bounds`_). ``default_bounds`` is not used for
+        :class:`~psycopg2:psycopg2.extras.NumericRange` inputs.
+
 ``DateTimeRangeField``
 ----------------------
 
-.. class:: DateTimeRangeField(**options)
+.. class:: DateTimeRangeField(default_bounds='[)', **options)
 
     Stores a range of timestamps. Based on a
     :class:`~django.db.models.DateTimeField`. Represented by a ``tstzrange`` in
     the database and a :class:`~psycopg2:psycopg2.extras.DateTimeTZRange` in
     Python.
 
+    .. attribute:: DateTimeRangeField.default_bounds
+
+        .. versionadded:: 4.1
+
+        Optional. The value of ``bounds`` for list and tuple inputs. The
+        default is lower bound included, upper bound excluded, that is ``[)``
+        (see the PostgreSQL documentation for details about
+        `different bounds`_). ``default_bounds`` is not used for
+        :class:`~psycopg2:psycopg2.extras.DateTimeTZRange` inputs.
+
 ``DateRangeField``
 ------------------
 
@@ -884,3 +904,5 @@ used with a custom range functions that expected boundaries, for example to
 define :class:`~django.contrib.postgres.constraints.ExclusionConstraint`. See
 `the PostgreSQL documentation for the full details <https://www.postgresql.org/
 docs/current/rangetypes.html#RANGETYPES-INCLUSIVITY>`_.
+
+.. _different bounds: https://www.postgresql.org/docs/current/rangetypes.html#RANGETYPES-IO

+ 6 - 0
docs/releases/4.1.txt

@@ -76,6 +76,12 @@ Minor features
   supports covering exclusion constraints using SP-GiST indexes on PostgreSQL
   14+.
 
+* The new ``default_bounds`` attribute of :attr:`DateTimeRangeField
+  <django.contrib.postgres.fields.DateTimeRangeField.default_bounds>` and
+  :attr:`DecimalRangeField
+  <django.contrib.postgres.fields.DecimalRangeField.default_bounds>` allows
+  specifying bounds for list and tuple inputs.
+
 :mod:`django.contrib.redirects`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

+ 11 - 2
tests/postgres_tests/fields.py

@@ -26,14 +26,23 @@ except ImportError:
             })
             return name, path, args, kwargs
 
+    class DummyContinuousRangeField(models.Field):
+        def __init__(self, *args, default_bounds='[)', **kwargs):
+            super().__init__(**kwargs)
+
+        def deconstruct(self):
+            name, path, args, kwargs = super().deconstruct()
+            kwargs['default_bounds'] = '[)'
+            return name, path, args, kwargs
+
     ArrayField = DummyArrayField
     BigIntegerRangeField = models.Field
     CICharField = models.Field
     CIEmailField = models.Field
     CITextField = models.Field
     DateRangeField = models.Field
-    DateTimeRangeField = models.Field
-    DecimalRangeField = models.Field
+    DateTimeRangeField = DummyContinuousRangeField
+    DecimalRangeField = DummyContinuousRangeField
     HStoreField = models.Field
     IntegerRangeField = models.Field
     SearchVector = models.Expression

+ 1 - 0
tests/postgres_tests/migrations/0002_create_test_models.py

@@ -249,6 +249,7 @@ class Migration(migrations.Migration):
                 ('decimals', DecimalRangeField(null=True, blank=True)),
                 ('timestamps', DateTimeRangeField(null=True, blank=True)),
                 ('timestamps_inner', DateTimeRangeField(null=True, blank=True)),
+                ('timestamps_closed_bounds', DateTimeRangeField(null=True, blank=True, default_bounds='[]')),
                 ('dates', DateRangeField(null=True, blank=True)),
                 ('dates_inner', DateRangeField(null=True, blank=True)),
             ],

+ 3 - 0
tests/postgres_tests/models.py

@@ -135,6 +135,9 @@ class RangesModel(PostgreSQLModel):
     decimals = DecimalRangeField(blank=True, null=True)
     timestamps = DateTimeRangeField(blank=True, null=True)
     timestamps_inner = DateTimeRangeField(blank=True, null=True)
+    timestamps_closed_bounds = DateTimeRangeField(
+        blank=True, null=True, default_bounds='[]',
+    )
     dates = DateRangeField(blank=True, null=True)
     dates_inner = DateRangeField(blank=True, null=True)
 

+ 92 - 3
tests/postgres_tests/test_ranges.py

@@ -50,6 +50,41 @@ class BasicTests(PostgreSQLSimpleTestCase):
                 instance = Model(field=value)
                 self.assertEqual(instance.get_field_display(), display)
 
+    def test_discrete_range_fields_unsupported_default_bounds(self):
+        discrete_range_types = [
+            pg_fields.BigIntegerRangeField,
+            pg_fields.IntegerRangeField,
+            pg_fields.DateRangeField,
+        ]
+        for field_type in discrete_range_types:
+            msg = f"Cannot use 'default_bounds' with {field_type.__name__}."
+            with self.assertRaisesMessage(TypeError, msg):
+                field_type(choices=[((51, 100), '51-100')], default_bounds='[]')
+
+    def test_continuous_range_fields_default_bounds(self):
+        continuous_range_types = [
+            pg_fields.DecimalRangeField,
+            pg_fields.DateTimeRangeField,
+        ]
+        for field_type in continuous_range_types:
+            field = field_type(choices=[((51, 100), '51-100')], default_bounds='[]')
+            self.assertEqual(field.default_bounds, '[]')
+
+    def test_invalid_default_bounds(self):
+        tests = [')]', ')[', '](', '])', '([', '[(', 'x', '', None]
+        msg = "default_bounds must be one of '[)', '(]', '()', or '[]'."
+        for invalid_bounds in tests:
+            with self.assertRaisesMessage(ValueError, msg):
+                pg_fields.DecimalRangeField(default_bounds=invalid_bounds)
+
+    def test_deconstruct(self):
+        field = pg_fields.DecimalRangeField()
+        *_, kwargs = field.deconstruct()
+        self.assertEqual(kwargs, {})
+        field = pg_fields.DecimalRangeField(default_bounds='[]')
+        *_, kwargs = field.deconstruct()
+        self.assertEqual(kwargs, {'default_bounds': '[]'})
+
 
 class TestSaveLoad(PostgreSQLTestCase):
 
@@ -83,6 +118,19 @@ class TestSaveLoad(PostgreSQLTestCase):
         loaded = RangesModel.objects.get()
         self.assertEqual(NumericRange(0, 10), loaded.ints)
 
+    def test_tuple_range_with_default_bounds(self):
+        range_ = (timezone.now(), timezone.now() + datetime.timedelta(hours=1))
+        RangesModel.objects.create(timestamps_closed_bounds=range_, timestamps=range_)
+        loaded = RangesModel.objects.get()
+        self.assertEqual(
+            loaded.timestamps_closed_bounds,
+            DateTimeTZRange(range_[0], range_[1], '[]'),
+        )
+        self.assertEqual(
+            loaded.timestamps,
+            DateTimeTZRange(range_[0], range_[1], '[)'),
+        )
+
     def test_range_object_boundaries(self):
         r = NumericRange(0, 10, '[]')
         instance = RangesModel(decimals=r)
@@ -91,6 +139,16 @@ class TestSaveLoad(PostgreSQLTestCase):
         self.assertEqual(r, loaded.decimals)
         self.assertIn(10, loaded.decimals)
 
+    def test_range_object_boundaries_range_with_default_bounds(self):
+        range_ = DateTimeTZRange(
+            timezone.now(),
+            timezone.now() + datetime.timedelta(hours=1),
+            bounds='()',
+        )
+        RangesModel.objects.create(timestamps_closed_bounds=range_)
+        loaded = RangesModel.objects.get()
+        self.assertEqual(loaded.timestamps_closed_bounds, range_)
+
     def test_unbounded(self):
         r = NumericRange(None, None, '()')
         instance = RangesModel(decimals=r)
@@ -478,6 +536,8 @@ class TestSerialization(PostgreSQLSimpleTestCase):
         '"bigints": null, "timestamps": "{\\"upper\\": \\"2014-02-02T12:12:12+00:00\\", '
         '\\"lower\\": \\"2014-01-01T00:00:00+00:00\\", \\"bounds\\": \\"[)\\"}", '
         '"timestamps_inner": null, '
+        '"timestamps_closed_bounds": "{\\"upper\\": \\"2014-02-02T12:12:12+00:00\\", '
+        '\\"lower\\": \\"2014-01-01T00:00:00+00:00\\", \\"bounds\\": \\"()\\"}", '
         '"dates": "{\\"upper\\": \\"2014-02-02\\", \\"lower\\": \\"2014-01-01\\", \\"bounds\\": \\"[)\\"}", '
         '"dates_inner": null }, '
         '"model": "postgres_tests.rangesmodel", "pk": null}]'
@@ -492,15 +552,19 @@ class TestSerialization(PostgreSQLSimpleTestCase):
         instance = RangesModel(
             ints=NumericRange(0, 10), decimals=NumericRange(empty=True),
             timestamps=DateTimeTZRange(self.lower_dt, self.upper_dt),
+            timestamps_closed_bounds=DateTimeTZRange(
+                self.lower_dt, self.upper_dt, bounds='()',
+            ),
             dates=DateRange(self.lower_date, self.upper_date),
         )
         data = serializers.serialize('json', [instance])
         dumped = json.loads(data)
-        for field in ('ints', 'dates', 'timestamps'):
+        for field in ('ints', 'dates', 'timestamps', 'timestamps_closed_bounds'):
             dumped[0]['fields'][field] = json.loads(dumped[0]['fields'][field])
         check = json.loads(self.test_data)
-        for field in ('ints', 'dates', 'timestamps'):
+        for field in ('ints', 'dates', 'timestamps', 'timestamps_closed_bounds'):
             check[0]['fields'][field] = json.loads(check[0]['fields'][field])
+
         self.assertEqual(dumped, check)
 
     def test_loading(self):
@@ -510,6 +574,10 @@ class TestSerialization(PostgreSQLSimpleTestCase):
         self.assertIsNone(instance.bigints)
         self.assertEqual(instance.dates, DateRange(self.lower_date, self.upper_date))
         self.assertEqual(instance.timestamps, DateTimeTZRange(self.lower_dt, self.upper_dt))
+        self.assertEqual(
+            instance.timestamps_closed_bounds,
+            DateTimeTZRange(self.lower_dt, self.upper_dt, bounds='()'),
+        )
 
     def test_serialize_range_with_null(self):
         instance = RangesModel(ints=NumericRange(None, 10))
@@ -886,26 +954,47 @@ class TestFormField(PostgreSQLSimpleTestCase):
         model_field = pg_fields.IntegerRangeField()
         form_field = model_field.formfield()
         self.assertIsInstance(form_field, pg_forms.IntegerRangeField)
+        self.assertEqual(form_field.range_kwargs, {})
 
     def test_model_field_formfield_biginteger(self):
         model_field = pg_fields.BigIntegerRangeField()
         form_field = model_field.formfield()
         self.assertIsInstance(form_field, pg_forms.IntegerRangeField)
+        self.assertEqual(form_field.range_kwargs, {})
 
     def test_model_field_formfield_float(self):
-        model_field = pg_fields.DecimalRangeField()
+        model_field = pg_fields.DecimalRangeField(default_bounds='()')
         form_field = model_field.formfield()
         self.assertIsInstance(form_field, pg_forms.DecimalRangeField)
+        self.assertEqual(form_field.range_kwargs, {'bounds': '()'})
 
     def test_model_field_formfield_date(self):
         model_field = pg_fields.DateRangeField()
         form_field = model_field.formfield()
         self.assertIsInstance(form_field, pg_forms.DateRangeField)
+        self.assertEqual(form_field.range_kwargs, {})
 
     def test_model_field_formfield_datetime(self):
         model_field = pg_fields.DateTimeRangeField()
         form_field = model_field.formfield()
         self.assertIsInstance(form_field, pg_forms.DateTimeRangeField)
+        self.assertEqual(
+            form_field.range_kwargs,
+            {'bounds': pg_fields.ranges.CANONICAL_RANGE_BOUNDS},
+        )
+
+    def test_model_field_formfield_datetime_default_bounds(self):
+        model_field = pg_fields.DateTimeRangeField(default_bounds='[]')
+        form_field = model_field.formfield()
+        self.assertIsInstance(form_field, pg_forms.DateTimeRangeField)
+        self.assertEqual(form_field.range_kwargs, {'bounds': '[]'})
+
+    def test_model_field_with_default_bounds(self):
+        field = pg_forms.DateTimeRangeField(default_bounds='[]')
+        value = field.clean(['2014-01-01 00:00:00', '2014-02-03 12:13:14'])
+        lower = datetime.datetime(2014, 1, 1, 0, 0, 0)
+        upper = datetime.datetime(2014, 2, 3, 12, 13, 14)
+        self.assertEqual(value, DateTimeTZRange(lower, upper, '[]'))
 
     def test_has_changed(self):
         for field, value in (