Selaa lähdekoodia

Added HStoreField.

Thanks to `django-hstore` for inspiration in some areas, and many people
for reviews.
Marc Tamlyn 11 vuotta sitten
vanhempi
commit
36f514f065

+ 1 - 0
django/contrib/postgres/__init__.py

@@ -0,0 +1 @@
+default_app_config = 'django.contrib.postgres.apps.PostgresConfig'

+ 13 - 0
django/contrib/postgres/apps.py

@@ -0,0 +1,13 @@
+from django.apps import AppConfig
+from django.db.backends.signals import connection_created
+from django.utils.translation import ugettext_lazy as _
+
+from .signals import register_hstore_handler
+
+
+class PostgresConfig(AppConfig):
+    name = 'django.contrib.postgres'
+    verbose_name = _('PostgreSQL extensions')
+
+    def ready(self):
+        connection_created.connect(register_hstore_handler)

+ 1 - 0
django/contrib/postgres/fields/__init__.py

@@ -1 +1,2 @@
 from .array import *  # NOQA
+from .hstore import *  # NOQA

+ 1 - 1
django/contrib/postgres/fields/array.py

@@ -168,7 +168,7 @@ class ArrayContainsLookup(Lookup):
         lhs, lhs_params = self.process_lhs(qn, connection)
         rhs, rhs_params = self.process_rhs(qn, connection)
         params = lhs_params + rhs_params
-        type_cast = self.lhs.source.db_type(connection)
+        type_cast = self.lhs.output_field.db_type(connection)
         return '%s @> %s::%s' % (lhs, rhs, type_cast), params
 
 

+ 145 - 0
django/contrib/postgres/fields/hstore.py

@@ -0,0 +1,145 @@
+import json
+
+from django.contrib.postgres import forms
+from django.contrib.postgres.fields.array import ArrayField
+from django.core import exceptions
+from django.db.models import Field, Lookup, Transform, TextField
+from django.utils import six
+from django.utils.translation import ugettext_lazy as _
+
+
+__all__ = ['HStoreField']
+
+
+class HStoreField(Field):
+    empty_strings_allowed = False
+    description = _('Map of strings to strings')
+    default_error_messages = {
+        'not_a_string': _('The value of "%(key)s" is not a string.'),
+    }
+
+    def db_type(self, connection):
+        return 'hstore'
+
+    def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False):
+        if lookup_type == 'contains':
+            return [self.get_prep_value(value)]
+        return super(HStoreField, self).get_db_prep_lookup(lookup_type, value,
+                connection, prepared=False)
+
+    def get_transform(self, name):
+        transform = super(HStoreField, self).get_transform(name)
+        if transform:
+            return transform
+        return KeyTransformFactory(name)
+
+    def validate(self, value, model_instance):
+        super(HStoreField, self).validate(value, model_instance)
+        for key, val in value.items():
+            if not isinstance(val, six.string_types):
+                raise exceptions.ValidationError(
+                    self.error_messages['not_a_string'],
+                    code='not_a_string',
+                    params={'key': key},
+                )
+
+    def to_python(self, value):
+        if isinstance(value, six.string_types):
+            value = json.loads(value)
+        return value
+
+    def value_to_string(self, obj):
+        value = self._get_val_from_obj(obj)
+        return json.dumps(value)
+
+    def formfield(self, **kwargs):
+        defaults = {
+            'form_class': forms.HStoreField,
+        }
+        defaults.update(kwargs)
+        return super(HStoreField, self).formfield(**defaults)
+
+
+@HStoreField.register_lookup
+class HStoreContainsLookup(Lookup):
+    lookup_name = 'contains'
+
+    def as_sql(self, qn, connection):
+        lhs, lhs_params = self.process_lhs(qn, connection)
+        rhs, rhs_params = self.process_rhs(qn, connection)
+        params = lhs_params + rhs_params
+        return '%s @> %s' % (lhs, rhs), params
+
+
+@HStoreField.register_lookup
+class HStoreContainedByLookup(Lookup):
+    lookup_name = 'contained_by'
+
+    def as_sql(self, qn, connection):
+        lhs, lhs_params = self.process_lhs(qn, connection)
+        rhs, rhs_params = self.process_rhs(qn, connection)
+        params = lhs_params + rhs_params
+        return '%s <@ %s' % (lhs, rhs), params
+
+
+@HStoreField.register_lookup
+class HasKeyLookup(Lookup):
+    lookup_name = 'has_key'
+
+    def as_sql(self, qn, connection):
+        lhs, lhs_params = self.process_lhs(qn, connection)
+        rhs, rhs_params = self.process_rhs(qn, connection)
+        params = lhs_params + rhs_params
+        return '%s ? %s' % (lhs, rhs), params
+
+
+@HStoreField.register_lookup
+class HasKeysLookup(Lookup):
+    lookup_name = 'has_keys'
+
+    def as_sql(self, qn, connection):
+        lhs, lhs_params = self.process_lhs(qn, connection)
+        rhs, rhs_params = self.process_rhs(qn, connection)
+        params = lhs_params + rhs_params
+        return '%s ?& %s' % (lhs, rhs), params
+
+
+class KeyTransform(Transform):
+    output_field = TextField()
+
+    def __init__(self, key_name, *args, **kwargs):
+        super(KeyTransform, self).__init__(*args, **kwargs)
+        self.key_name = key_name
+
+    def as_sql(self, qn, connection):
+        lhs, params = qn.compile(self.lhs)
+        return "%s -> '%s'" % (lhs, self.key_name), params
+
+
+class KeyTransformFactory(object):
+
+    def __init__(self, key_name):
+        self.key_name = key_name
+
+    def __call__(self, *args, **kwargs):
+        return KeyTransform(self.key_name, *args, **kwargs)
+
+
+@HStoreField.register_lookup
+class KeysTransform(Transform):
+    lookup_name = 'keys'
+    output_field = ArrayField(TextField())
+
+    def as_sql(self, qn, connection):
+        lhs, params = qn.compile(self.lhs)
+        return 'akeys(%s)' % lhs, params
+
+
+@HStoreField.register_lookup
+class ValuesTransform(Transform):
+    lookup_name = 'values'
+    output_field = ArrayField(TextField())
+
+    def as_sql(self, qn, connection):
+        lhs, params = qn.compile(self.lhs)
+        return 'avals(%s)' % lhs, params

+ 1 - 0
django/contrib/postgres/forms/__init__.py

@@ -1 +1,2 @@
 from .array import *  # NOQA
+from .hstore import *  # NOQA

+ 37 - 0
django/contrib/postgres/forms/hstore.py

@@ -0,0 +1,37 @@
+import json
+
+from django import forms
+from django.core.exceptions import ValidationError
+from django.utils import six
+from django.utils.translation import ugettext_lazy as _
+
+
+__all__ = ['HStoreField']
+
+
+class HStoreField(forms.CharField):
+    """A field for HStore data which accepts JSON input."""
+    widget = forms.Textarea
+    default_error_messages = {
+        'invalid_json': _('Could not load JSON data.'),
+    }
+
+    def prepare_value(self, value):
+        if isinstance(value, dict):
+            return json.dumps(value)
+        return value
+
+    def to_python(self, value):
+        if not value:
+            return {}
+        try:
+            value = json.loads(value)
+        except ValueError:
+            raise ValidationError(
+                self.error_messages['invalid_json'],
+                code='invalid_json',
+            )
+        # Cast everything to strings for ease.
+        for key, val in value.items():
+            value[key] = six.text_type(val)
+        return value

+ 34 - 0
django/contrib/postgres/operations.py

@@ -0,0 +1,34 @@
+from django.contrib.postgres.signals import register_hstore_handler
+from django.db.migrations.operations.base import Operation
+
+
+class CreateExtension(Operation):
+    reversible = True
+
+    def __init__(self, name):
+        self.name = name
+
+    def state_forwards(self, app_label, state):
+        pass
+
+    def database_forwards(self, app_label, schema_editor, from_state, to_state):
+        schema_editor.execute("CREATE EXTENSION IF NOT EXISTS %s" % self.name)
+
+    def database_backwards(self, app_label, schema_editor, from_state, to_state):
+        schema_editor.execute("DROP EXTENSION %s" % self.name)
+
+    def describe(self):
+        return "Creates extension %s" % self.name
+
+
+class HStoreExtension(CreateExtension):
+
+    def __init__(self):
+        self.name = 'hstore'
+
+    def database_forwards(self, app_label, schema_editor, from_state, to_state):
+        super(HStoreExtension, self).database_forwards(app_label, schema_editor, from_state, to_state)
+        # Register hstore straight away as it cannot be done before the
+        # extension is installed, a subsequent data migration would use the
+        # same connection
+        register_hstore_handler(schema_editor.connection)

+ 25 - 0
django/contrib/postgres/signals.py

@@ -0,0 +1,25 @@
+from django.utils import six
+
+from psycopg2 import ProgrammingError
+from psycopg2.extras import register_hstore
+
+
+def register_hstore_handler(connection, **kwargs):
+    if connection.vendor != 'postgresql':
+        return
+
+    try:
+        if six.PY2:
+            register_hstore(connection.connection, globally=True, unicode=True)
+        else:
+            register_hstore(connection.connection, globally=True)
+    except ProgrammingError:
+        # Hstore is not available on the database.
+        #
+        # If someone tries to create an hstore field it will error there.
+        # This is necessary as someone may be using PSQL without extensions
+        # installed but be using other features of contrib.postgres.
+        #
+        # This is also needed in order to create the connection in order to
+        # install the hstore extension.
+        pass

+ 50 - 1
django/contrib/postgres/validators.py

@@ -1,5 +1,9 @@
+import copy
+
+from django.core.exceptions import ValidationError
 from django.core.validators import MaxLengthValidator, MinLengthValidator
-from django.utils.translation import ungettext_lazy
+from django.utils.deconstruct import deconstructible
+from django.utils.translation import ungettext_lazy, ugettext_lazy as _
 
 
 class ArrayMaxLengthValidator(MaxLengthValidator):
@@ -14,3 +18,48 @@ class ArrayMinLengthValidator(MinLengthValidator):
         'List contains %(show_value)d item, it should contain no fewer than %(limit_value)d.',
         'List contains %(show_value)d items, it should contain no fewer than %(limit_value)d.',
         'limit_value')
+
+
+@deconstructible
+class KeysValidator(object):
+    """A validator designed for HStore to require/restrict keys."""
+
+    messages = {
+        'missing_keys': _('Some keys were missing: %(keys)s'),
+        'extra_keys': _('Some unknown keys were provided: %(keys)s'),
+    }
+    strict = False
+
+    def __init__(self, keys, strict=False, messages=None):
+        self.keys = set(keys)
+        self.strict = strict
+        if messages is not None:
+            self.messages = copy.copy(self.messages)
+            self.messages.update(messages)
+
+    def __call__(self, value):
+        keys = set(value.keys())
+        missing_keys = self.keys - keys
+        if missing_keys:
+            raise ValidationError(self.messages['missing_keys'],
+                code='missing_keys',
+                params={'keys': ', '.join(missing_keys)},
+            )
+        if self.strict:
+            extra_keys = keys - self.keys
+            if extra_keys:
+                raise ValidationError(self.messages['extra_keys'],
+                    code='extra_keys',
+                    params={'keys': ', '.join(extra_keys)},
+                )
+
+    def __eq__(self, other):
+        return (
+            isinstance(other, self.__class__)
+            and (self.keys == other.keys)
+            and (self.messages == other.messages)
+            and (self.strict == other.strict)
+        )
+
+    def __ne__(self, other):
+        return not (self == other)

+ 164 - 2
docs/ref/contrib/postgres/fields.txt

@@ -61,8 +61,8 @@ ArrayField
     When nesting ``ArrayField``, whether you use the `size` parameter or not,
     PostgreSQL requires that the arrays are rectangular::
 
-        from django.db import models
         from django.contrib.postgres.fields import ArrayField
+        from django.db import models
 
         class Board(models.Model):
             pieces = ArrayField(ArrayField(models.IntegerField()))
@@ -95,7 +95,7 @@ We will use the following example model::
         name = models.CharField(max_length=200)
         tags = ArrayField(models.CharField(max_length=200), blank=True)
 
-        def __str__(self):  # __unicode__ on python 2
+        def __str__(self):  # __unicode__ on Python 2
             return self.name
 
 .. fieldlookup:: arrayfield.contains
@@ -240,3 +240,165 @@ At present using :attr:`~django.db.models.Field.db_index` will create a
 ``btree`` index. This does not offer particularly significant help to querying.
 A more useful index is a ``GIN`` index, which you should create using a
 :class:`~django.db.migrations.operations.RunSQL` operation.
+
+HStoreField
+-----------
+
+.. class:: HStoreField(**options)
+
+    A field for storing mappings of strings to strings. The Python data type
+    used is a ``dict``.
+
+.. note::
+
+    On occasions it may be useful to require or restrict the keys which are
+    valid for a given field. This can be done using the
+    :class:`~django.contrib.postgres.validators.KeysValidator`.
+
+Querying HStoreField
+^^^^^^^^^^^^^^^^^^^^
+
+In addition to the ability to query by key, there are a number of custom
+lookups available for ``HStoreField``.
+
+We will use the following example model::
+
+    from django.contrib.postgres.fields import HStoreField
+    from django.db import models
+
+    class Dog(models.Model):
+        name = models.CharField(max_length=200)
+        data = HStoreField()
+
+        def __str__(self):  # __unicode__ on Python 2
+            return self.name
+
+.. fieldlookup:: hstorefield.key
+
+Key lookups
+~~~~~~~~~~~
+
+To query based on a given key, you simply use that key as the lookup name::
+
+    >>> Dog.objects.create(name='Rufus', data={'breed': 'labrador'})
+    >>> Dog.objects.create(name='Meg', data={'breed': 'collie'})
+
+    >>> Dog.objects.filter(data__breed='collie')
+    [<Dog: Meg>]
+
+You can chain other lookups after key lookups::
+
+    >>> Dog.objects.filter(data__breed__contains='l')
+    [<Dog: Rufus>, Dog: Meg>]
+
+If the key you wish to query by clashes with the name of another lookup, you
+need to use the :lookup:`hstorefield.contains` lookup instead.
+
+.. warning::
+
+    Since any string could be a key in a hstore value, any lookup other than
+    those listed below will be interpreted as a key lookup. No errors are
+    raised. Be extra careful for typing mistakes, and always check your queries
+    work as you intend.
+
+.. fieldlookup:: hstorefield.contains
+
+contains
+~~~~~~~~
+
+The :lookup:`contains` lookup is overridden on
+:class:`~django.contrib.postgres.fields.HStoreField`. The returned objects are
+those where the given ``dict`` of key-value pairs are all contained in the
+field. It uses the SQL operator ``@>``. For example::
+
+    >>> Dog.objects.create(name='Rufus', data={'breed': 'labrador', 'owner': 'Bob'})
+    >>> Dog.objects.create(name='Meg', data={'breed': 'collie', 'owner': 'Bob'})
+    >>> Dog.objects.create(name='Fred', data={})
+
+    >>> Dog.objects.filter(data__contains={'owner': 'Bob'})
+    [<Dog: Rufus>, <Dog: Meg>]
+
+    >>> Dog.objects.filter(data__contains={'breed': 'collie'})
+    [<Dog: Meg>]
+
+.. fieldlookup:: hstorefield.contained_by
+
+contained_by
+~~~~~~~~~~~~
+
+This is the inverse of the :lookup:`contains <hstorefield.contains>` lookup -
+the objects returned will be those where the key-value pairs on the object are
+a subset of those in the value passed. It uses the SQL operator ``<@``. For
+example::
+
+    >>> Dog.objects.create(name='Rufus', data={'breed': 'labrador', 'owner': 'Bob'})
+    >>> Dog.objects.create(name='Meg', data={'breed': 'collie', 'owner': 'Bob'})
+    >>> Dog.objects.create(name='Fred', data={})
+
+    >>> Dog.objects.filter(data__contained_by={'breed': 'collie', 'owner': 'Bob'})
+    [<Dog: Meg>, <Dog: Fred>]
+
+    >>> Dog.objects.filter(data__contained_by={'breed': 'collie'})
+    [<Dog: Fred>]
+
+.. fieldlookup:: hstorefield.has_key
+
+has_key
+~~~~~~~
+
+Returns objects where the given key is in the data. Uses the SQL operator
+``?``. For example::
+
+    >>> Dog.objects.create(name='Rufus', data={'breed': 'labrador'})
+    >>> Dog.objects.create(name='Meg', data={'breed': 'collie', 'owner': 'Bob'})
+
+    >>> Dog.objects.filter(data__has_key='owner')
+    [<Dog: Meg>]
+
+.. fieldlookup:: hstorefield.has_keys
+
+has_keys
+~~~~~~~~
+
+Returns objects where all of the given keys are in the data. Uses the SQL operator
+``?&``. For example::
+
+    >>> Dog.objects.create(name='Rufus', data={})
+    >>> Dog.objects.create(name='Meg', data={'breed': 'collie', 'owner': 'Bob'})
+
+    >>> Dog.objects.filter(data__has_keys=['breed', 'owner'])
+    [<Dog: Meg>]
+
+.. fieldlookup:: hstorefield.keys
+
+keys
+~~~~
+
+Returns objects where the array of keys is the given value. Note that the order
+is not guaranteed to be reliable, so this transform is mainly useful for using
+in conjunction with lookups on
+:class:`~django.contrib.postgres.fields.ArrayField`. Uses the SQL function
+``akeys()``. For example::
+
+    >>> Dog.objects.create(name='Rufus', data={'toy': 'bone'})
+    >>> Dog.objects.create(name='Meg', data={'breed': 'collie', 'owner': 'Bob'})
+
+    >>> Dog.objects.filter(data__keys__overlap=['breed', 'toy'])
+    [<Dog: Rufus>, <Dog: Meg>]
+
+.. fieldlookup:: hstorefield.values
+
+values
+~~~~~~
+
+Returns objects where the array of values is the given value. Note that the
+order is not guaranteed to be reliable, so this transform is mainly useful for
+using in conjunction with lookups on
+:class:`~django.contrib.postgres.fields.ArrayField`. Uses the SQL function
+``avalues()``. For example::
+
+    >>> Dog.objects.create(name='Rufus', data={'breed': 'labrador'})
+    >>> Dog.objects.create(name='Meg', data={'breed': 'collie', 'owner': 'Bob'})
+
+    >>> Dog.objects.filter(data__values__contains=['collie'])
+    [<Dog: Meg>]

+ 20 - 0
docs/ref/contrib/postgres/forms.txt

@@ -133,3 +133,23 @@ SplitArrayField
             ['1', '2', '']  # -> [1, 2]
             ['1', '', '3']  # -> [1, None, 3]
             ['', '2', '']  # -> [None, 2]
+
+HStoreField
+-----------
+
+.. class:: HStoreField
+
+    A field which accepts JSON encoded data for an
+    :class:`~django.contrib.postgres.fields.HStoreField`. It will cast all the
+    values to strings. It is represented by an HTML ``<textarea>``.
+
+    .. admonition:: User friendly forms
+
+        ``HStoreField`` is not particularly user friendly in most cases,
+        however it is a useful way to format data from a client-side widget for
+        submission to the server.
+
+    .. note::
+        On occasions it may be useful to require or restrict the keys which are
+        valid for a given field. This can be done using the
+        :class:`~django.contrib.postgres.validators.KeysValidator`.

+ 2 - 0
docs/ref/contrib/postgres/index.txt

@@ -26,3 +26,5 @@ a number of PostgreSQL specific data types.
 
     fields
     forms
+    operations
+    validators

+ 27 - 0
docs/ref/contrib/postgres/operations.txt

@@ -0,0 +1,27 @@
+Database migration operations
+=============================
+
+All of these :doc:`operations </ref/migration-operations>` are available from
+the ``django.contrib.postgres.operations`` module.
+
+.. currentmodule:: django.contrib.postgres.operations
+
+CreateExtension
+---------------
+
+.. class:: CreateExtension(name)
+
+    An ``Operation`` subclass which installs PostgreSQL extensions.
+
+    .. attribute:: name
+
+        This is a required argument. The name of the extension to be installed.
+
+HStoreExtension
+---------------
+
+.. class:: HStoreExtension()
+
+    A subclass of :class:`~django.contrib.postgres.operations.CreateExtension`
+    which will install the ``hstore`` extension and also immediately set up the
+    connection to interpret hstore data.

+ 20 - 0
docs/ref/contrib/postgres/validators.txt

@@ -0,0 +1,20 @@
+==========
+Validators
+==========
+
+.. module:: django.contrib.postgres.validators
+
+``KeysValidator``
+-----------------
+
+.. class:: KeysValidator(keys, strict=False, messages=None)
+
+    Validates that the given keys are contained in the value. If ``strict`` is
+    ``True``, then it also checks that there are no other keys present.
+
+    The ``messages`` passed should be a dict containing the keys
+    ``missing_keys`` and/or ``extra_keys``.
+
+    .. note::
+        Note that this checks only for the existence of a given key, not that
+        the value of a key is non-empty.

+ 8 - 0
docs/releases/1.8.txt

@@ -35,6 +35,14 @@ site.
 
 .. _django-secure: https://pypi.python.org/pypi/django-secure
 
+New PostgreSQL specific functionality
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Django now has a module with extensions for PostgreSQL specific features, such
+as :class:`~django.contrib.postgres.fields.ArrayField` and
+:class:`~django.contrib.postgres.fields.HStoreField`. A full breakdown of the
+features is available :doc:`in the documentation</ref/contrib/postgres/index>`.
+
 New data types
 ~~~~~~~~~~~~~~
 

+ 15 - 0
tests/postgres_tests/migrations/0001_setup_extensions.py

@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+from __future__ import unicode_literals
+
+from django.contrib.postgres.operations import HStoreExtension
+from django.db import models, migrations
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+    ]
+
+    operations = [
+        HStoreExtension(),
+    ]

+ 76 - 0
tests/postgres_tests/migrations/0002_create_test_models.py

@@ -0,0 +1,76 @@
+# -*- coding: utf-8 -*-
+from __future__ import unicode_literals
+
+from django.db import models, migrations
+import django.contrib.postgres.fields
+import django.contrib.postgres.fields.hstore
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('postgres_tests', '0001_setup_extensions'),
+    ]
+
+    operations = [
+        migrations.CreateModel(
+            name='CharArrayModel',
+            fields=[
+                ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
+                ('field', django.contrib.postgres.fields.ArrayField(models.CharField(max_length=10), size=None)),
+            ],
+            options={
+            },
+            bases=(models.Model,),
+        ),
+        migrations.CreateModel(
+            name='DateTimeArrayModel',
+            fields=[
+                ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
+                ('field', django.contrib.postgres.fields.ArrayField(models.DateTimeField(), size=None)),
+            ],
+            options={
+            },
+            bases=(models.Model,),
+        ),
+        migrations.CreateModel(
+            name='HStoreModel',
+            fields=[
+                ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
+                ('field', django.contrib.postgres.fields.hstore.HStoreField(blank=True, null=True)),
+            ],
+            options={
+            },
+            bases=(models.Model,),
+        ),
+        migrations.CreateModel(
+            name='IntegerArrayModel',
+            fields=[
+                ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
+                ('field', django.contrib.postgres.fields.ArrayField(models.IntegerField(), size=None)),
+            ],
+            options={
+            },
+            bases=(models.Model,),
+        ),
+        migrations.CreateModel(
+            name='NestedIntegerArrayModel',
+            fields=[
+                ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
+                ('field', django.contrib.postgres.fields.ArrayField(django.contrib.postgres.fields.ArrayField(models.IntegerField(), size=None), size=None)),
+            ],
+            options={
+            },
+            bases=(models.Model,),
+        ),
+        migrations.CreateModel(
+            name='NullableIntegerArrayModel',
+            fields=[
+                ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
+                ('field', django.contrib.postgres.fields.ArrayField(models.IntegerField(), size=None, null=True, blank=True)),
+            ],
+            options={
+            },
+            bases=(models.Model,),
+        ),
+    ]

+ 0 - 0
tests/postgres_tests/migrations/__init__.py


+ 5 - 1
tests/postgres_tests/models.py

@@ -1,4 +1,4 @@
-from django.contrib.postgres.fields import ArrayField
+from django.contrib.postgres.fields import ArrayField, HStoreField
 from django.db import models
 
 
@@ -20,3 +20,7 @@ class DateTimeArrayModel(models.Model):
 
 class NestedIntegerArrayModel(models.Model):
     field = ArrayField(ArrayField(models.IntegerField()))
+
+
+class HStoreModel(models.Model):
+    field = HStoreField(blank=True, null=True)

+ 218 - 0
tests/postgres_tests/test_hstore.py

@@ -0,0 +1,218 @@
+import json
+import unittest
+
+from django.contrib.postgres import forms
+from django.contrib.postgres.fields import HStoreField
+from django.contrib.postgres.validators import KeysValidator
+from django.core import exceptions, serializers
+from django.db import connection
+from django.test import TestCase
+
+from .models import HStoreModel
+
+
+@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required')
+class SimpleTests(TestCase):
+    apps = ['django.contrib.postgres']
+
+    def test_save_load_success(self):
+        value = {'a': 'b'}
+        instance = HStoreModel(field=value)
+        instance.save()
+        reloaded = HStoreModel.objects.get()
+        self.assertEqual(reloaded.field, value)
+
+    def test_null(self):
+        instance = HStoreModel(field=None)
+        instance.save()
+        reloaded = HStoreModel.objects.get()
+        self.assertEqual(reloaded.field, None)
+
+    def test_value_null(self):
+        value = {'a': None}
+        instance = HStoreModel(field=value)
+        instance.save()
+        reloaded = HStoreModel.objects.get()
+        self.assertEqual(reloaded.field, value)
+
+
+@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required')
+class TestQuerying(TestCase):
+
+    def setUp(self):
+        self.objs = [
+            HStoreModel.objects.create(field={'a': 'b'}),
+            HStoreModel.objects.create(field={'a': 'b', 'c': 'd'}),
+            HStoreModel.objects.create(field={'c': 'd'}),
+            HStoreModel.objects.create(field={}),
+            HStoreModel.objects.create(field=None),
+        ]
+
+    def test_exact(self):
+        self.assertSequenceEqual(
+            HStoreModel.objects.filter(field__exact={'a': 'b'}),
+            self.objs[:1]
+        )
+
+    def test_contained_by(self):
+        self.assertSequenceEqual(
+            HStoreModel.objects.filter(field__contained_by={'a': 'b', 'c': 'd'}),
+            self.objs[:4]
+        )
+
+    def test_contains(self):
+        self.assertSequenceEqual(
+            HStoreModel.objects.filter(field__contains={'a': 'b'}),
+            self.objs[:2]
+        )
+
+    def test_has_key(self):
+        self.assertSequenceEqual(
+            HStoreModel.objects.filter(field__has_key='c'),
+            self.objs[1:3]
+        )
+
+    def test_has_keys(self):
+        self.assertSequenceEqual(
+            HStoreModel.objects.filter(field__has_keys=['a', 'c']),
+            self.objs[1:2]
+        )
+
+    def test_key_transform(self):
+        self.assertSequenceEqual(
+            HStoreModel.objects.filter(field__a='b'),
+            self.objs[:2]
+        )
+
+    def test_keys(self):
+        self.assertSequenceEqual(
+            HStoreModel.objects.filter(field__keys=['a']),
+            self.objs[:1]
+        )
+
+    def test_values(self):
+        self.assertSequenceEqual(
+            HStoreModel.objects.filter(field__values=['b']),
+            self.objs[:1]
+        )
+
+    def test_field_chaining(self):
+        self.assertSequenceEqual(
+            HStoreModel.objects.filter(field__a__contains='b'),
+            self.objs[:2]
+        )
+
+    def test_keys_contains(self):
+        self.assertSequenceEqual(
+            HStoreModel.objects.filter(field__keys__contains=['a']),
+            self.objs[:2]
+        )
+
+    def test_values_overlap(self):
+        self.assertSequenceEqual(
+            HStoreModel.objects.filter(field__values__overlap=['b', 'd']),
+            self.objs[:3]
+        )
+
+
+@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL required')
+class TestSerialization(TestCase):
+    test_data = '[{"fields": {"field": "{\\"a\\": \\"b\\"}"}, "model": "postgres_tests.hstoremodel", "pk": null}]'
+
+    def test_dumping(self):
+        instance = HStoreModel(field={'a': 'b'})
+        data = serializers.serialize('json', [instance])
+        self.assertEqual(json.loads(data), json.loads(self.test_data))
+
+    def test_loading(self):
+        instance = list(serializers.deserialize('json', self.test_data))[0].object
+        self.assertEqual(instance.field, {'a': 'b'})
+
+
+class TestValidation(TestCase):
+
+    def test_not_a_string(self):
+        field = HStoreField()
+        with self.assertRaises(exceptions.ValidationError) as cm:
+            field.clean({'a': 1}, None)
+        self.assertEqual(cm.exception.code, 'not_a_string')
+        self.assertEqual(cm.exception.message % cm.exception.params, 'The value of "a" is not a string.')
+
+
+class TestFormField(TestCase):
+
+    def test_valid(self):
+        field = forms.HStoreField()
+        value = field.clean('{"a": "b"}')
+        self.assertEqual(value, {'a': 'b'})
+
+    def test_invalid_json(self):
+        field = forms.HStoreField()
+        with self.assertRaises(exceptions.ValidationError) as cm:
+            field.clean('{"a": "b"')
+        self.assertEqual(cm.exception.messages[0], 'Could not load JSON data.')
+        self.assertEqual(cm.exception.code, 'invalid_json')
+
+    def test_not_string_values(self):
+        field = forms.HStoreField()
+        value = field.clean('{"a": 1}')
+        self.assertEqual(value, {'a': '1'})
+
+    def test_empty(self):
+        field = forms.HStoreField(required=False)
+        value = field.clean('')
+        self.assertEqual(value, {})
+
+    def test_model_field_formfield(self):
+        model_field = HStoreField()
+        form_field = model_field.formfield()
+        self.assertIsInstance(form_field, forms.HStoreField)
+
+
+class TestValidator(TestCase):
+
+    def test_simple_valid(self):
+        validator = KeysValidator(keys=['a', 'b'])
+        validator({'a': 'foo', 'b': 'bar', 'c': 'baz'})
+
+    def test_missing_keys(self):
+        validator = KeysValidator(keys=['a', 'b'])
+        with self.assertRaises(exceptions.ValidationError) as cm:
+            validator({'a': 'foo', 'c': 'baz'})
+        self.assertEqual(cm.exception.messages[0], 'Some keys were missing: b')
+        self.assertEqual(cm.exception.code, 'missing_keys')
+
+    def test_strict_valid(self):
+        validator = KeysValidator(keys=['a', 'b'], strict=True)
+        validator({'a': 'foo', 'b': 'bar'})
+
+    def test_extra_keys(self):
+        validator = KeysValidator(keys=['a', 'b'], strict=True)
+        with self.assertRaises(exceptions.ValidationError) as cm:
+            validator({'a': 'foo', 'b': 'bar', 'c': 'baz'})
+        self.assertEqual(cm.exception.messages[0], 'Some unknown keys were provided: c')
+        self.assertEqual(cm.exception.code, 'extra_keys')
+
+    def test_custom_messages(self):
+        messages = {
+            'missing_keys': 'Foobar',
+        }
+        validator = KeysValidator(keys=['a', 'b'], strict=True, messages=messages)
+        with self.assertRaises(exceptions.ValidationError) as cm:
+            validator({'a': 'foo', 'c': 'baz'})
+        self.assertEqual(cm.exception.messages[0], 'Foobar')
+        self.assertEqual(cm.exception.code, 'missing_keys')
+        with self.assertRaises(exceptions.ValidationError) as cm:
+            validator({'a': 'foo', 'b': 'bar', 'c': 'baz'})
+        self.assertEqual(cm.exception.messages[0], 'Some unknown keys were provided: c')
+        self.assertEqual(cm.exception.code, 'extra_keys')
+
+    def test_deconstruct(self):
+        messages = {
+            'missing_keys': 'Foobar',
+        }
+        validator = KeysValidator(keys=['a', 'b'], strict=True, messages=messages)
+        path, args, kwargs = validator.deconstruct()
+        self.assertEqual(path, 'django.contrib.postgres.validators.KeysValidator')
+        self.assertEqual(args, ())
+        self.assertEqual(kwargs, {'keys': ['a', 'b'], 'strict': True, 'messages': messages})

+ 1 - 1
tests/runtests.py

@@ -78,7 +78,7 @@ def get_test_modules():
                     os.path.isfile(f) or
                     not os.path.exists(os.path.join(dirpath, f, '__init__.py'))):
                 continue
-            if not connection.vendor == 'postgresql' and f == 'postgres_tests':
+            if not connection.vendor == 'postgresql' and f == 'postgres_tests' or f == 'postgres':
                 continue
             modules.append((modpath, f))
     return modules