Explorar o código

Fixed #24001 -- Added range fields for PostgreSQL.

Added support for PostgreSQL range types to contrib.postgres.

- 5 new model fields
- 4 new form fields
- New validators
- Uses psycopg2's range type implementation in python
Marc Tamlyn %!s(int64=10) %!d(string=hai) anos
pai
achega
48ad288679

+ 1 - 0
django/contrib/postgres/fields/__init__.py

@@ -1,2 +1,3 @@
 from .array import *  # NOQA
 from .array import *  # NOQA
 from .hstore import *  # NOQA
 from .hstore import *  # NOQA
+from .ranges import *  # NOQA

+ 156 - 0
django/contrib/postgres/fields/ranges.py

@@ -0,0 +1,156 @@
+import json
+
+from django.contrib.postgres import lookups, forms
+from django.db import models
+from django.utils import six
+
+from psycopg2.extras import Range, NumericRange, DateRange, DateTimeTZRange
+
+
+__all__ = [
+    'RangeField', 'IntegerRangeField', 'BigIntegerRangeField',
+    'FloatRangeField', 'DateTimeRangeField', 'DateRangeField',
+]
+
+
+class RangeField(models.Field):
+    empty_strings_allowed = False
+
+    def get_prep_value(self, value):
+        if value is None:
+            return None
+        elif isinstance(value, Range):
+            return value
+        elif isinstance(value, (list, tuple)):
+            return self.range_type(value[0], value[1])
+        return value
+
+    def to_python(self, value):
+        if isinstance(value, six.string_types):
+            value = self.range_type(**json.loads(value))
+        elif isinstance(value, (list, tuple)):
+            value = self.range_type(value[0], value[1])
+        return value
+
+    def value_to_string(self, obj):
+        value = self._get_val_from_obj(obj)
+        if value is None:
+            return None
+        if value.isempty:
+            return json.dumps({"empty": True})
+        return json.dumps({
+            "lower": value.lower,
+            "upper": value.upper,
+            "bounds": value._bounds,
+        })
+
+    def formfield(self, **kwargs):
+        kwargs.setdefault('form_class', self.form_field)
+        return super(RangeField, self).formfield(**kwargs)
+
+
+class IntegerRangeField(RangeField):
+    base_field = models.IntegerField()
+    range_type = NumericRange
+    form_field = forms.IntegerRangeField
+
+    def db_type(self, connection):
+        return 'int4range'
+
+
+class BigIntegerRangeField(RangeField):
+    base_field = models.BigIntegerField()
+    range_type = NumericRange
+    form_field = forms.IntegerRangeField
+
+    def db_type(self, connection):
+        return 'int8range'
+
+
+class FloatRangeField(RangeField):
+    base_field = models.FloatField()
+    range_type = NumericRange
+    form_field = forms.FloatRangeField
+
+    def db_type(self, connection):
+        return 'numrange'
+
+
+class DateTimeRangeField(RangeField):
+    base_field = models.DateTimeField()
+    range_type = DateTimeTZRange
+    form_field = forms.DateTimeRangeField
+
+    def db_type(self, connection):
+        return 'tstzrange'
+
+
+class DateRangeField(RangeField):
+    base_field = models.DateField()
+    range_type = DateRange
+    form_field = forms.DateRangeField
+
+    def db_type(self, connection):
+        return 'daterange'
+
+
+RangeField.register_lookup(lookups.DataContains)
+RangeField.register_lookup(lookups.ContainedBy)
+RangeField.register_lookup(lookups.Overlap)
+
+
+@RangeField.register_lookup
+class FullyLessThan(lookups.PostgresSimpleLookup):
+    lookup_name = 'fully_lt'
+    operator = '<<'
+
+
+@RangeField.register_lookup
+class FullGreaterThan(lookups.PostgresSimpleLookup):
+    lookup_name = 'fully_gt'
+    operator = '>>'
+
+
+@RangeField.register_lookup
+class NotLessThan(lookups.PostgresSimpleLookup):
+    lookup_name = 'not_lt'
+    operator = '&>'
+
+
+@RangeField.register_lookup
+class NotGreaterThan(lookups.PostgresSimpleLookup):
+    lookup_name = 'not_gt'
+    operator = '&<'
+
+
+@RangeField.register_lookup
+class AdjacentToLookup(lookups.PostgresSimpleLookup):
+    lookup_name = 'adjacent_to'
+    operator = '-|-'
+
+
+@RangeField.register_lookup
+class RangeStartsWith(lookups.FunctionTransform):
+    lookup_name = 'startswith'
+    function = 'lower'
+
+    @property
+    def output_field(self):
+        return self.lhs.output_field.base_field
+
+
+@RangeField.register_lookup
+class RangeEndsWith(lookups.FunctionTransform):
+    lookup_name = 'endswith'
+    function = 'upper'
+
+    @property
+    def output_field(self):
+        return self.lhs.output_field.base_field
+
+
+@RangeField.register_lookup
+class IsEmpty(lookups.FunctionTransform):
+    lookup_name = 'isempty'
+    function = 'isempty'
+    output_field = models.BooleanField()

+ 1 - 0
django/contrib/postgres/forms/__init__.py

@@ -1,2 +1,3 @@
 from .array import *  # NOQA
 from .array import *  # NOQA
 from .hstore import *  # NOQA
 from .hstore import *  # NOQA
+from .ranges import *  # NOQA

+ 69 - 0
django/contrib/postgres/forms/ranges.py

@@ -0,0 +1,69 @@
+from django.core import exceptions
+from django import forms
+from django.utils.translation import ugettext_lazy as _
+
+from psycopg2.extras import NumericRange, DateRange, DateTimeTZRange
+
+
+__all__ = ['IntegerRangeField', 'FloatRangeField', 'DateTimeRangeField', 'DateRangeField']
+
+
+class BaseRangeField(forms.MultiValueField):
+    default_error_messages = {
+        'invalid': _('Enter two valid values.'),
+        'bound_ordering': _('The start of the range must not exceed the end of the range.'),
+    }
+
+    def __init__(self, **kwargs):
+        widget = forms.MultiWidget([self.base_field.widget, self.base_field.widget])
+        kwargs.setdefault('widget', widget)
+        kwargs.setdefault('fields', [self.base_field(required=False), self.base_field(required=False)])
+        kwargs.setdefault('required', False)
+        kwargs.setdefault('require_all_fields', False)
+        super(BaseRangeField, self).__init__(**kwargs)
+
+    def prepare_value(self, value):
+        if isinstance(value, self.range_type):
+            return [value.lower, value.upper]
+        if value is None:
+            return [None, None]
+        return value
+
+    def compress(self, values):
+        if not values:
+            return None
+        lower, upper = values
+        if lower is not None and upper is not None and lower > upper:
+            raise exceptions.ValidationError(
+                self.error_messages['bound_ordering'],
+                code='bound_ordering',
+            )
+        try:
+            range_value = self.range_type(lower, upper)
+        except TypeError:
+            raise exceptions.ValidationError(
+                self.error_messages['invalid'],
+                code='invalid',
+            )
+        else:
+            return range_value
+
+
+class IntegerRangeField(BaseRangeField):
+    base_field = forms.IntegerField
+    range_type = NumericRange
+
+
+class FloatRangeField(BaseRangeField):
+    base_field = forms.FloatField
+    range_type = NumericRange
+
+
+class DateTimeRangeField(BaseRangeField):
+    base_field = forms.DateTimeField
+    range_type = DateTimeTZRange
+
+
+class DateRangeField(BaseRangeField):
+    base_field = forms.DateField
+    range_type = DateRange

+ 14 - 1
django/contrib/postgres/validators.py

@@ -1,7 +1,10 @@
 import copy
 import copy
 
 
 from django.core.exceptions import ValidationError
 from django.core.exceptions import ValidationError
-from django.core.validators import MaxLengthValidator, MinLengthValidator
+from django.core.validators import (
+    MaxLengthValidator, MinLengthValidator, MaxValueValidator,
+    MinValueValidator,
+)
 from django.utils.deconstruct import deconstructible
 from django.utils.deconstruct import deconstructible
 from django.utils.translation import ungettext_lazy, ugettext_lazy as _
 from django.utils.translation import ungettext_lazy, ugettext_lazy as _
 
 
@@ -63,3 +66,13 @@ class KeysValidator(object):
 
 
     def __ne__(self, other):
     def __ne__(self, other):
         return not (self == other)
         return not (self == other)
+
+
+class RangeMaxValueValidator(MaxValueValidator):
+    compare = lambda self, a, b: a.upper > b
+    message = _('Ensure that this range is completely less than or equal to %(limit_value)s.')
+
+
+class RangeMinValueValidator(MinValueValidator):
+    compare = lambda self, a, b: a.lower < b
+    message = _('Ensure that this range is completely greater than or equal to %(limit_value)s.')

+ 1 - 0
docs/conf.py

@@ -130,6 +130,7 @@ intersphinx_mapping = {
     'sphinx': ('http://sphinx-doc.org/', None),
     'sphinx': ('http://sphinx-doc.org/', None),
     'six': ('http://pythonhosted.org/six/', None),
     'six': ('http://pythonhosted.org/six/', None),
     'formtools': ('http://django-formtools.readthedocs.org/en/latest/', None),
     'formtools': ('http://django-formtools.readthedocs.org/en/latest/', None),
+    'psycopg2': ('http://initd.org/psycopg/docs/', None),
 }
 }
 
 
 # Python's docs don't change every week.
 # Python's docs don't change every week.

+ 263 - 0
docs/ref/contrib/postgres/fields.txt

@@ -402,3 +402,266 @@ using in conjunction with lookups on
 
 
     >>> Dog.objects.filter(data__values__contains=['collie'])
     >>> Dog.objects.filter(data__values__contains=['collie'])
     [<Dog: Meg>]
     [<Dog: Meg>]
+
+.. _range-fields:
+
+Range Fields
+------------
+
+There are five range field types, corresponding to the built-in range types in
+PostgreSQL. These fields are used to store a range of values; for example the
+start and end timestamps of an event, or the range of ages an activity is
+suitable for.
+
+All of the range fields translate to :ref:`psycopg2 Range objects
+<psycopg2:adapt-range>` in python, but also accept tuples as input if no bounds
+information is necessary. The default is lower bound included, upper bound
+excluded.
+
+IntegerRangeField
+^^^^^^^^^^^^^^^^^
+
+.. class:: IntegerRangeField(**options)
+
+    Stores a range of integers. Based on an
+    :class:`~django.db.models.IntegerField`. Represented by an ``int4range`` in
+    the database and a :class:`~psycopg2:psycopg2.extras.NumericRange` in
+    Python.
+
+BigIntegerRangeField
+^^^^^^^^^^^^^^^^^^^^
+
+.. class:: BigIntegerRangeField(**options)
+
+    Stores a range of large integers. Based on a
+    :class:`~django.db.models.BigIntegerField`. Represented by an ``int8range``
+    in the database and a :class:`~psycopg2:psycopg2.extras.NumericRange` in
+    Python.
+
+FloatRangeField
+^^^^^^^^^^^^^^^
+
+.. class:: FloatRangeField(**options)
+
+    Stores a range of floating point values. Based on a
+    :class:`~django.db.models.FloatField`. Represented by a ``numrange`` in the
+    database and a :class:`~psycopg2:psycopg2.extras.NumericRange` in Python.
+
+DateTimeRangeField
+^^^^^^^^^^^^^^^^^^
+
+.. class:: DateTimeRangeField(**options)
+
+    Stores a range of timestamps. Based on a
+    :class:`~django.db.models.DateTimeField`. Represented by a ``tztsrange`` in
+    the database and a :class:`~psycopg2:psycopg2.extras.DateTimeTZRange` in
+    Python.
+
+DateRangeField
+^^^^^^^^^^^^^^
+
+.. class:: DateRangeField(**options)
+
+    Stores a range of dates. Based on a
+    :class:`~django.db.models.DateField`. Represented by a ``daterange`` in the
+    database and a :class:`~psycopg2:psycopg2.extras.DateRange` in Python.
+
+Querying Range Fields
+^^^^^^^^^^^^^^^^^^^^^
+
+There are a number of custom lookups and transforms for range fields. They are
+available on all the above fields, but we will use the following example
+model::
+
+    from django.contrib.postgres.fields import IntegerRangeField
+    from django.db import models
+
+    class Event(models.Model):
+        name = models.CharField(max_length=200)
+        ages = IntegerRangeField()
+
+        def __str__(self):  # __unicode__ on Python 2
+            return self.name
+
+We will also use the following example objects::
+
+    >>> Event.objects.create(name='Soft play', ages=(0, 10))
+    >>> Event.objects.create(name='Pub trip', ages=(21, None))
+
+and ``NumericRange``:
+
+    >>> from psycopg2.extras import NumericRange
+
+Containment functions
+~~~~~~~~~~~~~~~~~~~~~
+
+As with other PostgreSQL fields, there are three standard containment
+operators: ``contains``, ``contained_by`` and ``overlap``, using the SQL
+operators ``@>``, ``<@``, and ``&&`` respectively.
+
+.. fieldlookup:: rangefield.contains
+
+contains
+''''''''
+
+    >>> Event.objects.filter(ages__contains=NumericRange(4, 5))
+    [<Event: Soft play>]
+
+.. fieldlookup:: rangefield.contained_by
+
+contained_by
+''''''''''''
+
+    >>> Event.objects.filter(ages__contained_by=NumericRange(0, 15))
+    [<Event: Soft play>]
+
+.. fieldlookup:: rangefield.overlap
+
+overlap
+'''''''
+
+    >>> Event.objects.filter(ages__overlap=NumericRange(8, 12))
+    [<Event: Soft play>]
+
+Comparison functions
+~~~~~~~~~~~~~~~~~~~~
+
+Range fields support the standard lookups: :lookup:`lt`, :lookup:`gt`,
+:lookup:`lte` and :lookup:`gte`. These are not particularly helpful - they
+compare the lower bounds first and then the upper bounds only if necessary.
+This is also the strategy used to order by a range field. It is better to use
+the specific range comparison operators.
+
+.. fieldlookup:: rangefield.fully_lt
+
+fully_lt
+''''''''
+
+The returned ranges are strictly less than the passed range. In other words,
+all the points in the returned range are less than all those in the passed
+range.
+
+    >>> Event.objects.filter(ages__fully_lt=NumericRange(11, 15))
+    [<Event: Soft play>]
+
+.. fieldlookup:: rangefield.fully_gt
+
+fully_gt
+''''''''
+
+The returned ranges are strictly greater than the passed range. In other words,
+the all the points in the returned range are greater than all those in the
+passed range.
+
+    >>> Event.objects.filter(ages__fully_gt=NumericRange(11, 15))
+    [<Event: Pub trip>]
+
+.. fieldlookup:: rangefield.not_lt
+
+not_lt
+''''''
+
+The returned ranges do not contain any points less than the passed range, that
+is the lower bound of the returned range is at least the lower bound of the
+passed range.
+
+    >>> Event.objects.filter(ages__not_lt=NumericRange(0, 15))
+    [<Event: Soft play>, <Event: Pub trip>]
+
+.. fieldlookup:: rangefield.not_gt
+
+not_gt
+''''''
+
+The returned ranges do not contain any points greater than the passed range, that
+is the upper bound of the returned range is at most the upper bound of the
+passed range.
+
+    >>> Event.objects.filter(ages__not_gt=NumericRange(3, 10))
+    [<Event: Soft play>]
+
+.. fieldlookup:: rangefield.adjacent_to
+
+adjacent_to
+'''''''''''
+
+The returned ranges share a bound with the passed range.
+
+    >>> Event.objects.filter(ages__adjacent_to=NumericRange(10, 21))
+    [<Event: Soft play>, <Event: Pub trip>]
+
+Querying using the bounds
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+There are three transforms available for use in queries. You can extract the
+lower or upper bound, or query based on emptiness.
+
+.. fieldlookup:: rangefield.startswith
+
+startswith
+''''''''''
+
+Returned objects have the given lower bound. Can be chained to valid lookups
+for the base field.
+
+    >>> Event.objects.filter(ages__startswith=21)
+    [<Event: Pub trip>]
+
+.. fieldlookup:: rangefield.endswith
+
+endswith
+''''''''
+
+Returned objects have the given upper bound. Can be chained to valid lookups
+for the base field.
+
+    >>> Event.objects.filter(ages__endswith=10)
+    [<Event: Soft play>]
+
+.. fieldlookup:: rangefield.isempty
+
+isempty
+'''''''
+
+Returned objects are empty ranges. Can be chained to valid lookups for a
+:class:`~django.db.models.BooleanField`.
+
+    >>> Event.objects.filter(ages__isempty=True)
+    []
+
+Defining your own range types
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+PostgreSQL allows the definition of custom range types. Django's model and form
+field implementations use base classes below, and psycopg2 provides a
+:func:`~psycopg2:psycopg2.extras.register_range` to allow use of custom range
+types.
+
+.. class:: RangeField(**options)
+
+    Base class for model range fields.
+
+    .. attribute:: base_field
+
+        The model field to use.
+
+    .. attribute:: range_type
+
+        The psycopg2 range type to use.
+
+    .. attribute:: form_field
+
+        The form field class to use. Should be a sublcass of
+        :class:`django.contrib.postgres.forms.BaseRangeField`.
+
+.. class:: django.contrib.postgres.forms.BaseRangeField
+
+    Base class for form range fields.
+
+    .. attribute:: base_field
+
+        The form field to use.
+
+    .. attribute:: range_type
+
+        The psycopg2 range type to use.

+ 45 - 0
docs/ref/contrib/postgres/forms.txt

@@ -154,3 +154,48 @@ HStoreField
         On occasions it may be useful to require or restrict the keys which are
         On occasions it may be useful to require or restrict the keys which are
         valid for a given field. This can be done using the
         valid for a given field. This can be done using the
         :class:`~django.contrib.postgres.validators.KeysValidator`.
         :class:`~django.contrib.postgres.validators.KeysValidator`.
+
+Range Fields
+------------
+
+This group of fields all share similar functionality for accepting range data.
+They are based on :class:`~django.forms.MultiValueField`. They treat one
+omitted value as an unbounded range. They also validate that the lower bound is
+not greater than the upper bound.
+
+IntegerRangeField
+~~~~~~~~~~~~~~~~~
+
+.. class:: IntegerRangeField
+
+    Based on :class:`~django.forms.IntegerField` and translates its input into
+    :class:`~psycopg2:psycopg2.extras.NumericRange`. Default for
+    :class:`~django.contrib.postgres.fields.IntegerRangeField` and
+    :class:`~django.contrib.postgres.fields.BigIntegerRangeField`.
+
+FloatRangeField
+~~~~~~~~~~~~~~~
+
+.. class:: FloatRangeField
+
+    Based on :class:`~django.forms.FloatField` and translates its input into
+    :class:`~psycopg2:psycopg2.extras.NumericRange`. Default for
+    :class:`~django.contrib.postgres.fields.FloatRangeField`.
+
+DateTimeRangeField
+~~~~~~~~~~~~~~~~~~
+
+.. class:: DateTimeRangeField
+
+    Based on :class:`~django.forms.DateTimeField` and translates its input into
+    :class:`~psycopg2:psycopg2.extras.DateTimeTZRange`. Default for
+    :class:`~django.contrib.postgres.fields.DateTimeRangeField`.
+
+DateRangeField
+~~~~~~~~~~~~~~
+
+.. class:: DateRangeField
+
+    Based on :class:`~django.forms.DateField` and translates its input into
+    :class:`~psycopg2:psycopg2.extras.DateRange`. Default for
+    :class:`~django.contrib.postgres.fields.DateRangeField`.

+ 13 - 0
docs/ref/contrib/postgres/validators.txt

@@ -18,3 +18,16 @@ Validators
     .. note::
     .. note::
         Note that this checks only for the existence of a given key, not that
         Note that this checks only for the existence of a given key, not that
         the value of a key is non-empty.
         the value of a key is non-empty.
+
+Range validators
+----------------
+
+.. class:: RangeMaxValueValidator(limit_value, message=None)
+
+    Validates that the upper bound of the range is not greater than
+    ``limit_value``.
+
+.. class:: RangeMinValueValidator(limit_value, message=None)
+
+    Validates that the lower bound of the range is not less than the
+    ``limit_value``.

+ 3 - 3
docs/releases/1.8.txt

@@ -62,9 +62,9 @@ New PostgreSQL specific functionality
 
 
 Django now has a module with extensions for PostgreSQL specific features, such
 Django now has a module with extensions for PostgreSQL specific features, such
 as :class:`~django.contrib.postgres.fields.ArrayField`,
 as :class:`~django.contrib.postgres.fields.ArrayField`,
-:class:`~django.contrib.postgres.fields.HStoreField`, and :lookup:`unaccent`
-lookup. A full breakdown of the features is available :doc:`in the
-documentation </ref/contrib/postgres/index>`.
+:class:`~django.contrib.postgres.fields.HStoreField`, :ref:`range-fields`, and
+:lookup:`unaccent` lookup. A full breakdown of the features is available
+:doc:`in the documentation </ref/contrib/postgres/index>`.
 
 
 New data types
 New data types
 ~~~~~~~~~~~~~~
 ~~~~~~~~~~~~~~

+ 1 - 0
docs/spelling_wordlist

@@ -309,6 +309,7 @@ irc
 iregex
 iregex
 iriencode
 iriencode
 ise
 ise
+isempty
 isnull
 isnull
 iso
 iso
 istartswith
 istartswith

+ 23 - 0
tests/postgres_tests/migrations/0002_create_test_models.py

@@ -92,3 +92,26 @@ class Migration(migrations.Migration):
             bases=None,
             bases=None,
         ),
         ),
     ]
     ]
+
+    pg_92_operations = [
+        migrations.CreateModel(
+            name='RangesModel',
+            fields=[
+                ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
+                ('ints', django.contrib.postgres.fields.IntegerRangeField(null=True, blank=True)),
+                ('bigints', django.contrib.postgres.fields.BigIntegerRangeField(null=True, blank=True)),
+                ('floats', django.contrib.postgres.fields.FloatRangeField(null=True, blank=True)),
+                ('timestamps', django.contrib.postgres.fields.DateTimeRangeField(null=True, blank=True)),
+                ('dates', django.contrib.postgres.fields.DateRangeField(null=True, blank=True)),
+            ],
+            options={
+            },
+            bases=(models.Model,),
+        ),
+    ]
+
+    def apply(self, project_state, schema_editor, collect_sql=False):
+        PG_VERSION = schema_editor.connection.pg_version
+        if PG_VERSION >= 90200:
+            self.operations = self.operations + self.pg_92_operations
+        return super(Migration, self).apply(project_state, schema_editor, collect_sql)

+ 19 - 2
tests/postgres_tests/models.py

@@ -1,5 +1,8 @@
-from django.contrib.postgres.fields import ArrayField, HStoreField
-from django.db import models
+from django.contrib.postgres.fields import (
+    ArrayField, HStoreField, IntegerRangeField, BigIntegerRangeField,
+    FloatRangeField, DateTimeRangeField, DateRangeField,
+)
+from django.db import connection, models
 
 
 
 
 class IntegerArrayModel(models.Model):
 class IntegerArrayModel(models.Model):
@@ -34,6 +37,20 @@ class TextFieldModel(models.Model):
     field = models.TextField()
     field = models.TextField()
 
 
 
 
+# Only create this model for databases which support it
+if connection.vendor == 'postgresql' and connection.pg_version >= 90200:
+    class RangesModel(models.Model):
+        ints = IntegerRangeField(blank=True, null=True)
+        bigints = BigIntegerRangeField(blank=True, null=True)
+        floats = FloatRangeField(blank=True, null=True)
+        timestamps = DateTimeRangeField(blank=True, null=True)
+        dates = DateRangeField(blank=True, null=True)
+else:
+    # create an object with this name so we don't have failing imports
+    class RangesModel(object):
+        pass
+
+
 class ArrayFieldSubclass(ArrayField):
 class ArrayFieldSubclass(ArrayField):
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
         super(ArrayFieldSubclass, self).__init__(models.IntegerField())
         super(ArrayFieldSubclass, self).__init__(models.IntegerField())

+ 1 - 0
tests/postgres_tests/test_array.py

@@ -237,6 +237,7 @@ class TestMigrations(TestCase):
         name, path, args, kwargs = field.deconstruct()
         name, path, args, kwargs = field.deconstruct()
         self.assertEqual(path, 'postgres_tests.models.ArrayFieldSubclass')
         self.assertEqual(path, 'postgres_tests.models.ArrayFieldSubclass')
 
 
+    @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required')
     @override_settings(MIGRATION_MODULES={
     @override_settings(MIGRATION_MODULES={
         "postgres_tests": "postgres_tests.array_default_migrations",
         "postgres_tests": "postgres_tests.array_default_migrations",
     })
     })

+ 376 - 0
tests/postgres_tests/test_ranges.py

@@ -0,0 +1,376 @@
+import datetime
+import json
+import unittest
+
+from django import forms
+from django.contrib.postgres import forms as pg_forms, fields as pg_fields
+from django.contrib.postgres.validators import RangeMaxValueValidator, RangeMinValueValidator
+from django.core import exceptions, serializers
+from django.db import connection
+from django.test import TestCase
+from django.utils import timezone
+
+from psycopg2.extras import NumericRange, DateTimeTZRange, DateRange
+
+from .models import RangesModel
+
+
+def skipUnlessPG92(test):
+    if not connection.vendor == 'postgresql':
+        return unittest.skip('PostgreSQL required')(test)
+    PG_VERSION = connection.pg_version
+    if PG_VERSION < 90200:
+        return unittest.skip('PostgreSQL >= 9.2 required')(test)
+    return test
+
+
+@skipUnlessPG92
+class TestSaveLoad(TestCase):
+
+    def test_all_fields(self):
+        now = timezone.now()
+        instance = RangesModel(
+            ints=NumericRange(0, 10),
+            bigints=NumericRange(10, 20),
+            floats=NumericRange(20, 30),
+            timestamps=DateTimeTZRange(now - datetime.timedelta(hours=1), now),
+            dates=DateRange(now.date() - datetime.timedelta(days=1), now.date()),
+        )
+        instance.save()
+        loaded = RangesModel.objects.get()
+        self.assertEqual(instance.ints, loaded.ints)
+        self.assertEqual(instance.bigints, loaded.bigints)
+        self.assertEqual(instance.floats, loaded.floats)
+        self.assertEqual(instance.timestamps, loaded.timestamps)
+        self.assertEqual(instance.dates, loaded.dates)
+
+    def test_range_object(self):
+        r = NumericRange(0, 10)
+        instance = RangesModel(ints=r)
+        instance.save()
+        loaded = RangesModel.objects.get()
+        self.assertEqual(r, loaded.ints)
+
+    def test_tuple(self):
+        instance = RangesModel(ints=(0, 10))
+        instance.save()
+        loaded = RangesModel.objects.get()
+        self.assertEqual(NumericRange(0, 10), loaded.ints)
+
+    def test_range_object_boundaries(self):
+        r = NumericRange(0, 10, '[]')
+        instance = RangesModel(floats=r)
+        instance.save()
+        loaded = RangesModel.objects.get()
+        self.assertEqual(r, loaded.floats)
+        self.assertTrue(10 in loaded.floats)
+
+    def test_unbounded(self):
+        r = NumericRange(None, None, '()')
+        instance = RangesModel(floats=r)
+        instance.save()
+        loaded = RangesModel.objects.get()
+        self.assertEqual(r, loaded.floats)
+
+    def test_empty(self):
+        r = NumericRange(empty=True)
+        instance = RangesModel(ints=r)
+        instance.save()
+        loaded = RangesModel.objects.get()
+        self.assertEqual(r, loaded.ints)
+
+    def test_null(self):
+        instance = RangesModel(ints=None)
+        instance.save()
+        loaded = RangesModel.objects.get()
+        self.assertEqual(None, loaded.ints)
+
+
+@skipUnlessPG92
+class TestQuerying(TestCase):
+
+    @classmethod
+    def setUpTestData(cls):
+        cls.objs = [
+            RangesModel.objects.create(ints=NumericRange(0, 10)),
+            RangesModel.objects.create(ints=NumericRange(5, 15)),
+            RangesModel.objects.create(ints=NumericRange(None, 0)),
+            RangesModel.objects.create(ints=NumericRange(empty=True)),
+            RangesModel.objects.create(ints=None),
+        ]
+
+    def test_exact(self):
+        self.assertSequenceEqual(
+            RangesModel.objects.filter(ints__exact=NumericRange(0, 10)),
+            [self.objs[0]],
+        )
+
+    def test_isnull(self):
+        self.assertSequenceEqual(
+            RangesModel.objects.filter(ints__isnull=True),
+            [self.objs[4]],
+        )
+
+    def test_isempty(self):
+        self.assertSequenceEqual(
+            RangesModel.objects.filter(ints__isempty=True),
+            [self.objs[3]],
+        )
+
+    def test_contains(self):
+        self.assertSequenceEqual(
+            RangesModel.objects.filter(ints__contains=8),
+            [self.objs[0], self.objs[1]],
+        )
+
+    def test_contains_range(self):
+        self.assertSequenceEqual(
+            RangesModel.objects.filter(ints__contains=NumericRange(3, 8)),
+            [self.objs[0]],
+        )
+
+    def test_contained_by(self):
+        self.assertSequenceEqual(
+            RangesModel.objects.filter(ints__contained_by=NumericRange(0, 20)),
+            [self.objs[0], self.objs[1], self.objs[3]],
+        )
+
+    def test_overlap(self):
+        self.assertSequenceEqual(
+            RangesModel.objects.filter(ints__overlap=NumericRange(3, 8)),
+            [self.objs[0], self.objs[1]],
+        )
+
+    def test_fully_lt(self):
+        self.assertSequenceEqual(
+            RangesModel.objects.filter(ints__fully_lt=NumericRange(5, 10)),
+            [self.objs[2]],
+        )
+
+    def test_fully_gt(self):
+        self.assertSequenceEqual(
+            RangesModel.objects.filter(ints__fully_gt=NumericRange(5, 10)),
+            [],
+        )
+
+    def test_not_lt(self):
+        self.assertSequenceEqual(
+            RangesModel.objects.filter(ints__not_lt=NumericRange(5, 10)),
+            [self.objs[1]],
+        )
+
+    def test_not_gt(self):
+        self.assertSequenceEqual(
+            RangesModel.objects.filter(ints__not_gt=NumericRange(5, 10)),
+            [self.objs[0], self.objs[2]],
+        )
+
+    def test_adjacent_to(self):
+        self.assertSequenceEqual(
+            RangesModel.objects.filter(ints__adjacent_to=NumericRange(0, 5)),
+            [self.objs[1], self.objs[2]],
+        )
+
+    def test_startswith(self):
+        self.assertSequenceEqual(
+            RangesModel.objects.filter(ints__startswith=0),
+            [self.objs[0]],
+        )
+
+    def test_endswith(self):
+        self.assertSequenceEqual(
+            RangesModel.objects.filter(ints__endswith=0),
+            [self.objs[2]],
+        )
+
+    def test_startswith_chaining(self):
+        self.assertSequenceEqual(
+            RangesModel.objects.filter(ints__startswith__gte=0),
+            [self.objs[0], self.objs[1]],
+        )
+
+
+@skipUnlessPG92
+class TestSerialization(TestCase):
+    test_data = (
+        '[{"fields": {"ints": "{\\"upper\\": 10, \\"lower\\": 0, '
+        '\\"bounds\\": \\"[)\\"}", "floats": "{\\"empty\\": true}", '
+        '"bigints": null, "timestamps": null, "dates": null}, '
+        '"model": "postgres_tests.rangesmodel", "pk": null}]'
+    )
+
+    def test_dumping(self):
+        instance = RangesModel(ints=NumericRange(0, 10), floats=NumericRange(empty=True))
+        data = serializers.serialize('json', [instance])
+        dumped = json.loads(data)
+        dumped[0]['fields']['ints'] = json.loads(dumped[0]['fields']['ints'])
+        check = json.loads(self.test_data)
+        check[0]['fields']['ints'] = json.loads(check[0]['fields']['ints'])
+        self.assertEqual(dumped, check)
+
+    def test_loading(self):
+        instance = list(serializers.deserialize('json', self.test_data))[0].object
+        self.assertEqual(instance.ints, NumericRange(0, 10))
+        self.assertEqual(instance.floats, NumericRange(empty=True))
+        self.assertEqual(instance.dates, None)
+
+
+class TestValidators(TestCase):
+
+    def test_max(self):
+        validator = RangeMaxValueValidator(5)
+        validator(NumericRange(0, 5))
+        with self.assertRaises(exceptions.ValidationError) as cm:
+            validator(NumericRange(0, 10))
+        self.assertEqual(cm.exception.messages[0], 'Ensure that this range is completely less than or equal to 5.')
+        self.assertEqual(cm.exception.code, 'max_value')
+
+    def test_min(self):
+        validator = RangeMinValueValidator(5)
+        validator(NumericRange(10, 15))
+        with self.assertRaises(exceptions.ValidationError) as cm:
+            validator(NumericRange(0, 10))
+        self.assertEqual(cm.exception.messages[0], 'Ensure that this range is completely greater than or equal to 5.')
+        self.assertEqual(cm.exception.code, 'min_value')
+
+
+class TestFormField(TestCase):
+
+    def test_valid_integer(self):
+        field = pg_forms.IntegerRangeField()
+        value = field.clean(['1', '2'])
+        self.assertEqual(value, NumericRange(1, 2))
+
+    def test_valid_floats(self):
+        field = pg_forms.FloatRangeField()
+        value = field.clean(['1.12345', '2.001'])
+        self.assertEqual(value, NumericRange(1.12345, 2.001))
+
+    def test_valid_timestamps(self):
+        field = pg_forms.DateTimeRangeField()
+        value = field.clean(['01/01/2014 00:00:00', '02/02/2014 12:12:12'])
+        lower = datetime.datetime(2014, 1, 1, 0, 0, 0)
+        upper = datetime.datetime(2014, 2, 2, 12, 12, 12)
+        self.assertEqual(value, DateTimeTZRange(lower, upper))
+
+    def test_valid_dates(self):
+        field = pg_forms.DateRangeField()
+        value = field.clean(['01/01/2014', '02/02/2014'])
+        lower = datetime.date(2014, 1, 1)
+        upper = datetime.date(2014, 2, 2)
+        self.assertEqual(value, DateRange(lower, upper))
+
+    def test_using_split_datetime_widget(self):
+        class SplitDateTimeRangeField(pg_forms.DateTimeRangeField):
+            base_field = forms.SplitDateTimeField
+
+        class SplitForm(forms.Form):
+            field = SplitDateTimeRangeField()
+
+        form = SplitForm()
+        self.assertHTMLEqual(str(form), '''
+            <tr>
+                <th>
+                <label for="id_field_0">Field:</label>
+                </th>
+                <td>
+                    <input id="id_field_0_0" name="field_0_0" type="text" />
+                    <input id="id_field_0_1" name="field_0_1" type="text" />
+                    <input id="id_field_1_0" name="field_1_0" type="text" />
+                    <input id="id_field_1_1" name="field_1_1" type="text" />
+                </td>
+            </tr>
+        ''')
+        form = SplitForm({
+            'field_0_0': '01/01/2014',
+            'field_0_1': '00:00:00',
+            'field_1_0': '02/02/2014',
+            'field_1_1': '12:12:12',
+        })
+        self.assertTrue(form.is_valid())
+        lower = datetime.datetime(2014, 1, 1, 0, 0, 0)
+        upper = datetime.datetime(2014, 2, 2, 12, 12, 12)
+        self.assertEqual(form.cleaned_data['field'], DateTimeTZRange(lower, upper))
+
+    def test_none(self):
+        field = pg_forms.IntegerRangeField(required=False)
+        value = field.clean(['', ''])
+        self.assertEqual(value, None)
+
+    def test_rendering(self):
+        class RangeForm(forms.Form):
+            ints = pg_forms.IntegerRangeField()
+
+        self.assertHTMLEqual(str(RangeForm()), '''
+        <tr>
+            <th><label for="id_ints_0">Ints:</label></th>
+            <td>
+                <input id="id_ints_0" name="ints_0" type="number" />
+                <input id="id_ints_1" name="ints_1" type="number" />
+            </td>
+        </tr>
+        ''')
+
+    def test_lower_bound_higher(self):
+        field = pg_forms.IntegerRangeField()
+        with self.assertRaises(exceptions.ValidationError) as cm:
+            field.clean(['10', '2'])
+        self.assertEqual(cm.exception.messages[0], 'The start of the range must not exceed the end of the range.')
+        self.assertEqual(cm.exception.code, 'bound_ordering')
+
+    def test_open(self):
+        field = pg_forms.IntegerRangeField()
+        value = field.clean(['', '0'])
+        self.assertEqual(value, NumericRange(None, 0))
+
+    def test_incorrect_data_type(self):
+        field = pg_forms.IntegerRangeField()
+        with self.assertRaises(exceptions.ValidationError) as cm:
+            field.clean('1')
+        self.assertEqual(cm.exception.messages[0], 'Enter two valid values.')
+        self.assertEqual(cm.exception.code, 'invalid')
+
+    def test_invalid_lower(self):
+        field = pg_forms.IntegerRangeField()
+        with self.assertRaises(exceptions.ValidationError) as cm:
+            field.clean(['a', '2'])
+        self.assertEqual(cm.exception.messages[0], 'Enter a whole number.')
+
+    def test_invalid_upper(self):
+        field = pg_forms.IntegerRangeField()
+        with self.assertRaises(exceptions.ValidationError) as cm:
+            field.clean(['1', 'b'])
+        self.assertEqual(cm.exception.messages[0], 'Enter a whole number.')
+
+    def test_required(self):
+        field = pg_forms.IntegerRangeField(required=True)
+        with self.assertRaises(exceptions.ValidationError) as cm:
+            field.clean(['', ''])
+        self.assertEqual(cm.exception.messages[0], 'This field is required.')
+        value = field.clean([1, ''])
+        self.assertEqual(value, NumericRange(1, None))
+
+    def test_model_field_formfield_integer(self):
+        model_field = pg_fields.IntegerRangeField()
+        form_field = model_field.formfield()
+        self.assertIsInstance(form_field, pg_forms.IntegerRangeField)
+
+    def test_model_field_formfield_biginteger(self):
+        model_field = pg_fields.BigIntegerRangeField()
+        form_field = model_field.formfield()
+        self.assertIsInstance(form_field, pg_forms.IntegerRangeField)
+
+    def test_model_field_formfield_float(self):
+        model_field = pg_fields.FloatRangeField()
+        form_field = model_field.formfield()
+        self.assertIsInstance(form_field, pg_forms.FloatRangeField)
+
+    def test_model_field_formfield_date(self):
+        model_field = pg_fields.DateRangeField()
+        form_field = model_field.formfield()
+        self.assertIsInstance(form_field, pg_forms.DateRangeField)
+
+    def test_model_field_formfield_datetime(self):
+        model_field = pg_fields.DateTimeRangeField()
+        form_field = model_field.formfield()
+        self.assertIsInstance(form_field, pg_forms.DateTimeRangeField)