Browse Source

Fixed #28161 -- Fixed return type of ArrayField(CITextField()).

Thanks Tim for the review.
Simon Charette 8 years ago
parent
commit
b91868507a

+ 3 - 3
django/contrib/postgres/apps.py

@@ -5,7 +5,7 @@ from django.db.models import CharField, TextField
 from django.utils.translation import gettext_lazy as _
 
 from .lookups import SearchLookup, TrigramSimilar, Unaccent
-from .signals import register_hstore_handler
+from .signals import register_type_handlers
 
 
 class PostgresConfig(AppConfig):
@@ -16,8 +16,8 @@ class PostgresConfig(AppConfig):
         # Connections may already exist before we are called.
         for conn in connections.all():
             if conn.connection is not None:
-                register_hstore_handler(conn)
-        connection_created.connect(register_hstore_handler)
+                register_type_handlers(conn)
+        connection_created.connect(register_type_handlers)
         CharField.register_lookup(Unaccent)
         TextField.register_lookup(Unaccent)
         CharField.register_lookup(SearchLookup)

+ 5 - 8
django/contrib/postgres/operations.py

@@ -1,4 +1,4 @@
-from django.contrib.postgres.signals import register_hstore_handler
+from django.contrib.postgres.signals import register_type_handlers
 from django.db.migrations.operations.base import Operation
 
 
@@ -15,6 +15,10 @@ class CreateExtension(Operation):
         if schema_editor.connection.vendor != 'postgresql':
             return
         schema_editor.execute("CREATE EXTENSION IF NOT EXISTS %s" % schema_editor.quote_name(self.name))
+        # Registering new type handlers cannot be done before the extension is
+        # installed, otherwise a subsequent data migration would use the same
+        # connection.
+        register_type_handlers(schema_editor.connection)
 
     def database_backwards(self, app_label, schema_editor, from_state, to_state):
         schema_editor.execute("DROP EXTENSION %s" % schema_editor.quote_name(self.name))
@@ -46,13 +50,6 @@ class HStoreExtension(CreateExtension):
     def __init__(self):
         self.name = 'hstore'
 
-    def database_forwards(self, app_label, schema_editor, from_state, to_state):
-        super().database_forwards(app_label, schema_editor, from_state, to_state)
-        # Register hstore straight away as it cannot be done before the
-        # extension is installed, a subsequent data migration would use the
-        # same connection
-        register_hstore_handler(schema_editor.connection)
-
 
 class TrigramExtension(CreateExtension):
 

+ 16 - 1
django/contrib/postgres/signals.py

@@ -1,8 +1,9 @@
+import psycopg2
 from psycopg2 import ProgrammingError
 from psycopg2.extras import register_hstore
 
 
-def register_hstore_handler(connection, **kwargs):
+def register_type_handlers(connection, **kwargs):
     if connection.vendor != 'postgresql':
         return
 
@@ -18,3 +19,17 @@ def register_hstore_handler(connection, **kwargs):
         # This is also needed in order to create the connection in order to
         # install the hstore extension.
         pass
+
+    try:
+        with connection.cursor() as cursor:
+            # Retrieve oids of citext arrays.
+            cursor.execute("SELECT typarray FROM pg_type WHERE typname = 'citext'")
+            oids = tuple(row[0] for row in cursor)
+        array_type = psycopg2.extensions.new_array_type(oids, 'citext[]', psycopg2.STRING)
+        psycopg2.extensions.register_type(array_type, None)
+    except ProgrammingError:
+        # citext is not available on the database.
+        #
+        # The same comments in the except block of the above call to
+        # register_hstore() also apply here.
+        pass

+ 3 - 0
docs/releases/1.11.1.txt

@@ -88,3 +88,6 @@ Bugfixes
 * Fixed a regression where ``Model._state.db`` wasn't set correctly on
   multi-table inheritance parent models after saving a child model
   (:ticket:`28166`).
+
+* Corrected the return type of ``ArrayField(CITextField())`` values retrieved
+  from the database (:ticket:`28161`).

+ 3 - 3
tests/postgres_tests/__init__.py

@@ -12,14 +12,14 @@ class PostgreSQLTestCase(TestCase):
     @classmethod
     def tearDownClass(cls):
         # No need to keep that signal overhead for non PostgreSQL-related tests.
-        from django.contrib.postgres.signals import register_hstore_handler
+        from django.contrib.postgres.signals import register_type_handlers
 
-        connection_created.disconnect(register_hstore_handler)
+        connection_created.disconnect(register_type_handlers)
         super().tearDownClass()
 
 
 @unittest.skipUnless(connection.vendor == 'postgresql', "PostgreSQL specific tests")
 # To locate the widget's template.
 @modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
-class PostgreSQLWidgetTestCase(WidgetTest):
+class PostgreSQLWidgetTestCase(WidgetTest, PostgreSQLTestCase):
     pass

+ 1 - 0
tests/postgres_tests/migrations/0002_create_test_models.py

@@ -142,6 +142,7 @@ class Migration(migrations.Migration):
                 ('name', CICharField(primary_key=True, max_length=255)),
                 ('email', CIEmailField()),
                 ('description', CITextField()),
+                ('array_field', ArrayField(CITextField(), null=True)),
             ],
             options={
                 'required_db_vendor': 'postgresql',

+ 1 - 0
tests/postgres_tests/models.py

@@ -106,6 +106,7 @@ class CITestModel(PostgreSQLModel):
     name = CICharField(primary_key=True, max_length=255)
     email = CIEmailField()
     description = CITextField()
+    array_field = ArrayField(CITextField(), null=True)
 
     def __str__(self):
         return self.name

+ 8 - 0
tests/postgres_tests/test_citext.py

@@ -4,11 +4,13 @@ strings and thus eliminates the need for operations such as iexact and other
 modifiers to enforce use of an index.
 """
 from django.db import IntegrityError
+from django.test.utils import modify_settings
 
 from . import PostgreSQLTestCase
 from .models import CITestModel
 
 
+@modify_settings(INSTALLED_APPS={'append': 'django.contrib.postgres'})
 class CITextTestCase(PostgreSQLTestCase):
 
     @classmethod
@@ -17,6 +19,7 @@ class CITextTestCase(PostgreSQLTestCase):
             name='JoHn',
             email='joHn@johN.com',
             description='Average Joe named JoHn',
+            array_field=['JoE', 'jOhn'],
         )
 
     def test_equal_lowercase(self):
@@ -34,3 +37,8 @@ class CITextTestCase(PostgreSQLTestCase):
         """
         with self.assertRaises(IntegrityError):
             CITestModel.objects.create(name='John')
+
+    def test_array_field(self):
+        instance = CITestModel.objects.get()
+        self.assertEqual(instance.array_field, self.john.array_field)
+        self.assertTrue(CITestModel.objects.filter(array_field__contains=['joe']).exists())