Browse Source

Fixed #32046 -- Added CreateCollation/RemoveCollation operations for PostgreSQL.

Tom Carrick 4 years ago
parent
commit
f5e07601b2

+ 97 - 0
django/contrib/postgres/operations.py

@@ -164,3 +164,100 @@ class RemoveIndexConcurrently(NotInTransactionMixin, RemoveIndex):
             to_model_state = to_state.models[app_label, self.model_name_lower]
             to_model_state = to_state.models[app_label, self.model_name_lower]
             index = to_model_state.get_index_by_name(self.name)
             index = to_model_state.get_index_by_name(self.name)
             schema_editor.add_index(model, index, concurrently=True)
             schema_editor.add_index(model, index, concurrently=True)
+
+
+class CollationOperation(Operation):
+    def __init__(self, name, locale, *, provider='libc', deterministic=True):
+        self.name = name
+        self.locale = locale
+        self.provider = provider
+        self.deterministic = deterministic
+
+    def state_forwards(self, app_label, state):
+        pass
+
+    def deconstruct(self):
+        kwargs = {'name': self.name, 'locale': self.locale}
+        if self.provider and self.provider != 'libc':
+            kwargs['provider'] = self.provider
+        if self.deterministic is False:
+            kwargs['deterministic'] = self.deterministic
+        return (
+            self.__class__.__qualname__,
+            [],
+            kwargs,
+        )
+
+    def create_collation(self, schema_editor):
+        if (
+            self.deterministic is False and
+            not schema_editor.connection.features.supports_non_deterministic_collations
+        ):
+            raise NotSupportedError(
+                'Non-deterministic collations require PostgreSQL 12+.'
+            )
+        if (
+            self.provider != 'libc' and
+            not schema_editor.connection.features.supports_alternate_collation_providers
+        ):
+            raise NotSupportedError('Non-libc providers require PostgreSQL 10+.')
+        args = {'locale': schema_editor.quote_name(self.locale)}
+        if self.provider != 'libc':
+            args['provider'] = schema_editor.quote_name(self.provider)
+        if self.deterministic is False:
+            args['deterministic'] = 'false'
+        schema_editor.execute('CREATE COLLATION %(name)s (%(args)s)' % {
+            'name': schema_editor.quote_name(self.name),
+            'args': ', '.join(f'{option}={value}' for option, value in args.items()),
+        })
+
+    def remove_collation(self, schema_editor):
+        schema_editor.execute(
+            'DROP COLLATION %s' % schema_editor.quote_name(self.name),
+        )
+
+
+class CreateCollation(CollationOperation):
+    """Create a collation."""
+    def database_forwards(self, app_label, schema_editor, from_state, to_state):
+        if (
+            schema_editor.connection.vendor != 'postgresql' or
+            not router.allow_migrate(schema_editor.connection.alias, app_label)
+        ):
+            return
+        self.create_collation(schema_editor)
+
+    def database_backwards(self, app_label, schema_editor, from_state, to_state):
+        if not router.allow_migrate(schema_editor.connection.alias, app_label):
+            return
+        self.remove_collation(schema_editor)
+
+    def describe(self):
+        return f'Create collation {self.name}'
+
+    @property
+    def migration_name_fragment(self):
+        return 'create_collation_%s' % self.name.lower()
+
+
+class RemoveCollation(CollationOperation):
+    """Remove a collation."""
+    def database_forwards(self, app_label, schema_editor, from_state, to_state):
+        if (
+            schema_editor.connection.vendor != 'postgresql' or
+            not router.allow_migrate(schema_editor.connection.alias, app_label)
+        ):
+            return
+        self.remove_collation(schema_editor)
+
+    def database_backwards(self, app_label, schema_editor, from_state, to_state):
+        if not router.allow_migrate(schema_editor.connection.alias, app_label):
+            return
+        self.create_collation(schema_editor)
+
+    def describe(self):
+        return f'Remove collation {self.name}'
+
+    @property
+    def migration_name_fragment(self):
+        return 'remove_collation_%s' % self.name.lower()

+ 1 - 0
django/db/backends/postgresql/features.py

@@ -100,3 +100,4 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     supports_covering_indexes = property(operator.attrgetter('is_postgresql_11'))
     supports_covering_indexes = property(operator.attrgetter('is_postgresql_11'))
     supports_covering_gist_indexes = property(operator.attrgetter('is_postgresql_12'))
     supports_covering_gist_indexes = property(operator.attrgetter('is_postgresql_12'))
     supports_non_deterministic_collations = property(operator.attrgetter('is_postgresql_12'))
     supports_non_deterministic_collations = property(operator.attrgetter('is_postgresql_12'))
+    supports_alternate_collation_providers = property(operator.attrgetter('is_postgresql_10'))

+ 10 - 0
docs/ref/contrib/postgres/fields.txt

@@ -285,6 +285,16 @@ transform do not change. For example::
     .. _citext: https://www.postgresql.org/docs/current/citext.html
     .. _citext: https://www.postgresql.org/docs/current/citext.html
     .. _the performance considerations: https://www.postgresql.org/docs/current/citext.html#id-1.11.7.17.7
     .. _the performance considerations: https://www.postgresql.org/docs/current/citext.html#id-1.11.7.17.7
 
 
+.. admonition:: Case-insensitive collations
+
+    On PostgreSQL 12+, it's preferable to use non-deterministic collations
+    instead of the ``citext`` extension. You can create them using the
+    :class:`~django.contrib.postgres.operations.CreateCollation` migration
+    operation. For more details, see :ref:`manage-postgresql-collations` and
+    the PostgreSQL documentation about `non-deterministic collations`_.
+
+    .. _non-deterministic collations: https://www.postgresql.org/docs/current/collation.html#COLLATION-NONDETERMINISTIC
+
 ``HStoreField``
 ``HStoreField``
 ===============
 ===============
 
 

+ 50 - 0
docs/ref/contrib/postgres/operations.txt

@@ -115,6 +115,56 @@ them. In that case, connect to your Django database and run the query
 
 
     Installs the ``unaccent`` extension.
     Installs the ``unaccent`` extension.
 
 
+.. _manage-postgresql-collations:
+
+Managing collations using migrations
+====================================
+
+.. versionadded:: 3.2
+
+If you need to filter or order a column using a particular collation that your
+operating system provides but PostgreSQL does not, you can manage collations in
+your database using a migration file. These collations can then be used with
+the ``db_collation`` parameter on :class:`~django.db.models.CharField`,
+:class:`~django.db.models.TextField`, and their subclasses.
+
+For example, to create a collation for German phone book ordering::
+
+    from django.contrib.postgres.operations import CreateCollation
+
+    class Migration(migrations.Migration):
+        ...
+
+        operations = [
+            CreateCollation(
+                'german_phonebook',
+                provider='icu',
+                locale='und-u-ks-level2',
+            ),
+            ...
+        ]
+
+.. class:: CreateCollation(name, locale, *, provider='libc', deterministic=True)
+
+    Creates a collation with the given ``name``, ``locale`` and ``provider``.
+
+    Set the ``deterministic`` parameter to ``False`` to create a
+    non-deterministic collation, such as for case-insensitive filtering.
+
+.. class:: RemoveCollation(name, locale, *, provider='libc', deterministic=True)
+
+    Removes the collations named ``name``.
+
+    When reversed this is creating a collation with the provided ``locale``,
+    ``provider``, and ``deterministic`` arguments. Therefore, ``locale`` is
+    required to make this operation reversible.
+
+.. admonition:: Restrictions
+
+    PostgreSQL 9.6 only supports the ``'libc'`` provider.
+
+    Non-deterministic collations are supported only on PostgreSQL 12+.
+
 Concurrent index operations
 Concurrent index operations
 ===========================
 ===========================
 
 

+ 5 - 0
docs/releases/3.2.txt

@@ -135,6 +135,11 @@ Minor features
   now checks that the extension already exists in the database and skips the
   now checks that the extension already exists in the database and skips the
   migration if so.
   migration if so.
 
 
+* The new :class:`~django.contrib.postgres.operations.CreateCollation` and
+  :class:`~django.contrib.postgres.operations.RemoveCollation` operations
+  allow creating and dropping collations on PostgreSQL. See
+  :ref:`manage-postgresql-collations` for more details.
+
 :mod:`django.contrib.redirects`
 :mod:`django.contrib.redirects`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 

+ 203 - 5
tests/postgres_tests/test_operations.py

@@ -1,11 +1,13 @@
 import unittest
 import unittest
+from unittest import mock
 
 
 from migrations.test_base import OperationTestBase
 from migrations.test_base import OperationTestBase
 
 
 from django.db import NotSupportedError, connection
 from django.db import NotSupportedError, connection
 from django.db.migrations.state import ProjectState
 from django.db.migrations.state import ProjectState
 from django.db.models import Index
 from django.db.models import Index
-from django.test import modify_settings, override_settings
+from django.db.utils import ProgrammingError
+from django.test import modify_settings, override_settings, skipUnlessDBFeature
 from django.test.utils import CaptureQueriesContext
 from django.test.utils import CaptureQueriesContext
 
 
 from . import PostgreSQLTestCase
 from . import PostgreSQLTestCase
@@ -13,8 +15,8 @@ from . import PostgreSQLTestCase
 try:
 try:
     from django.contrib.postgres.indexes import BrinIndex, BTreeIndex
     from django.contrib.postgres.indexes import BrinIndex, BTreeIndex
     from django.contrib.postgres.operations import (
     from django.contrib.postgres.operations import (
-        AddIndexConcurrently, BloomExtension, CreateExtension,
-        RemoveIndexConcurrently,
+        AddIndexConcurrently, BloomExtension, CreateCollation, CreateExtension,
+        RemoveCollation, RemoveIndexConcurrently,
     )
     )
 except ImportError:
 except ImportError:
     pass
     pass
@@ -148,7 +150,7 @@ class RemoveIndexConcurrentlyTests(OperationTestBase):
         self.assertEqual(kwargs, {'model_name': 'Pony', 'name': 'pony_pink_idx'})
         self.assertEqual(kwargs, {'model_name': 'Pony', 'name': 'pony_pink_idx'})
 
 
 
 
-class NoExtensionRouter():
+class NoMigrationRouter():
     def allow_migrate(self, db, app_label, **hints):
     def allow_migrate(self, db, app_label, **hints):
         return False
         return False
 
 
@@ -157,7 +159,7 @@ class NoExtensionRouter():
 class CreateExtensionTests(PostgreSQLTestCase):
 class CreateExtensionTests(PostgreSQLTestCase):
     app_label = 'test_allow_create_extention'
     app_label = 'test_allow_create_extention'
 
 
-    @override_settings(DATABASE_ROUTERS=[NoExtensionRouter()])
+    @override_settings(DATABASE_ROUTERS=[NoMigrationRouter()])
     def test_no_allow_migrate(self):
     def test_no_allow_migrate(self):
         operation = CreateExtension('tablefunc')
         operation = CreateExtension('tablefunc')
         project_state = ProjectState()
         project_state = ProjectState()
@@ -213,3 +215,199 @@ class CreateExtensionTests(PostgreSQLTestCase):
                 operation.database_backwards(self.app_label, editor, project_state, new_state)
                 operation.database_backwards(self.app_label, editor, project_state, new_state)
         self.assertEqual(len(captured_queries), 1)
         self.assertEqual(len(captured_queries), 1)
         self.assertIn('SELECT', captured_queries[0]['sql'])
         self.assertIn('SELECT', captured_queries[0]['sql'])
+
+
+@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
+class CreateCollationTests(PostgreSQLTestCase):
+    app_label = 'test_allow_create_collation'
+
+    @override_settings(DATABASE_ROUTERS=[NoMigrationRouter()])
+    def test_no_allow_migrate(self):
+        operation = CreateCollation('C_test', locale='C')
+        project_state = ProjectState()
+        new_state = project_state.clone()
+        # Don't create a collation.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_forwards(self.app_label, editor, project_state, new_state)
+        self.assertEqual(len(captured_queries), 0)
+        # Reversal.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_backwards(self.app_label, editor, new_state, project_state)
+        self.assertEqual(len(captured_queries), 0)
+
+    def test_create(self):
+        operation = CreateCollation('C_test', locale='C')
+        self.assertEqual(operation.migration_name_fragment, 'create_collation_c_test')
+        self.assertEqual(operation.describe(), 'Create collation C_test')
+        project_state = ProjectState()
+        new_state = project_state.clone()
+        # Create a collation.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_forwards(self.app_label, editor, project_state, new_state)
+        self.assertEqual(len(captured_queries), 1)
+        self.assertIn('CREATE COLLATION', captured_queries[0]['sql'])
+        # Creating the same collation raises an exception.
+        with self.assertRaisesMessage(ProgrammingError, 'already exists'):
+            with connection.schema_editor(atomic=True) as editor:
+                operation.database_forwards(self.app_label, editor, project_state, new_state)
+        # Reversal.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_backwards(self.app_label, editor, new_state, project_state)
+        self.assertEqual(len(captured_queries), 1)
+        self.assertIn('DROP COLLATION', captured_queries[0]['sql'])
+        # Deconstruction.
+        name, args, kwargs = operation.deconstruct()
+        self.assertEqual(name, 'CreateCollation')
+        self.assertEqual(args, [])
+        self.assertEqual(kwargs, {'name': 'C_test', 'locale': 'C'})
+
+    @skipUnlessDBFeature('supports_non_deterministic_collations')
+    def test_create_non_deterministic_collation(self):
+        operation = CreateCollation(
+            'case_insensitive_test',
+            'und-u-ks-level2',
+            provider='icu',
+            deterministic=False,
+        )
+        project_state = ProjectState()
+        new_state = project_state.clone()
+        # Create a collation.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_forwards(self.app_label, editor, project_state, new_state)
+        self.assertEqual(len(captured_queries), 1)
+        self.assertIn('CREATE COLLATION', captured_queries[0]['sql'])
+        # Reversal.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_backwards(self.app_label, editor, new_state, project_state)
+        self.assertEqual(len(captured_queries), 1)
+        self.assertIn('DROP COLLATION', captured_queries[0]['sql'])
+        # Deconstruction.
+        name, args, kwargs = operation.deconstruct()
+        self.assertEqual(name, 'CreateCollation')
+        self.assertEqual(args, [])
+        self.assertEqual(kwargs, {
+            'name': 'case_insensitive_test',
+            'locale': 'und-u-ks-level2',
+            'provider': 'icu',
+            'deterministic': False,
+        })
+
+    @skipUnlessDBFeature('supports_alternate_collation_providers')
+    def test_create_collation_alternate_provider(self):
+        operation = CreateCollation(
+            'german_phonebook_test',
+            provider='icu',
+            locale='de-u-co-phonebk',
+        )
+        project_state = ProjectState()
+        new_state = project_state.clone()
+        # Create an collation.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_forwards(self.app_label, editor, project_state, new_state)
+        self.assertEqual(len(captured_queries), 1)
+        self.assertIn('CREATE COLLATION', captured_queries[0]['sql'])
+        # Reversal.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_backwards(self.app_label, editor, new_state, project_state)
+        self.assertEqual(len(captured_queries), 1)
+        self.assertIn('DROP COLLATION', captured_queries[0]['sql'])
+
+    def test_nondeterministic_collation_not_supported(self):
+        operation = CreateCollation(
+            'case_insensitive_test',
+            provider='icu',
+            locale='und-u-ks-level2',
+            deterministic=False,
+        )
+        project_state = ProjectState()
+        new_state = project_state.clone()
+        msg = 'Non-deterministic collations require PostgreSQL 12+.'
+        with connection.schema_editor(atomic=False) as editor:
+            with mock.patch(
+                'django.db.backends.postgresql.features.DatabaseFeatures.'
+                'supports_non_deterministic_collations',
+                False,
+            ):
+                with self.assertRaisesMessage(NotSupportedError, msg):
+                    operation.database_forwards(self.app_label, editor, project_state, new_state)
+
+    def test_collation_with_icu_provider_raises_error(self):
+        operation = CreateCollation(
+            'german_phonebook',
+            provider='icu',
+            locale='de-u-co-phonebk',
+        )
+        project_state = ProjectState()
+        new_state = project_state.clone()
+        msg = 'Non-libc providers require PostgreSQL 10+.'
+        with connection.schema_editor(atomic=False) as editor:
+            with mock.patch(
+                'django.db.backends.postgresql.features.DatabaseFeatures.'
+                'supports_alternate_collation_providers',
+                False,
+            ):
+                with self.assertRaisesMessage(NotSupportedError, msg):
+                    operation.database_forwards(self.app_label, editor, project_state, new_state)
+
+
+@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
+class RemoveCollationTests(PostgreSQLTestCase):
+    app_label = 'test_allow_remove_collation'
+
+    @override_settings(DATABASE_ROUTERS=[NoMigrationRouter()])
+    def test_no_allow_migrate(self):
+        operation = RemoveCollation('C_test', locale='C')
+        project_state = ProjectState()
+        new_state = project_state.clone()
+        # Don't create a collation.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_forwards(self.app_label, editor, project_state, new_state)
+        self.assertEqual(len(captured_queries), 0)
+        # Reversal.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_backwards(self.app_label, editor, new_state, project_state)
+        self.assertEqual(len(captured_queries), 0)
+
+    def test_remove(self):
+        operation = CreateCollation('C_test', locale='C')
+        project_state = ProjectState()
+        new_state = project_state.clone()
+        with connection.schema_editor(atomic=False) as editor:
+            operation.database_forwards(self.app_label, editor, project_state, new_state)
+
+        operation = RemoveCollation('C_test', locale='C')
+        self.assertEqual(operation.migration_name_fragment, 'remove_collation_c_test')
+        self.assertEqual(operation.describe(), 'Remove collation C_test')
+        project_state = ProjectState()
+        new_state = project_state.clone()
+        # Remove a collation.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_forwards(self.app_label, editor, project_state, new_state)
+        self.assertEqual(len(captured_queries), 1)
+        self.assertIn('DROP COLLATION', captured_queries[0]['sql'])
+        # Removing a nonexistent collation raises an exception.
+        with self.assertRaisesMessage(ProgrammingError, 'does not exist'):
+            with connection.schema_editor(atomic=True) as editor:
+                operation.database_forwards(self.app_label, editor, project_state, new_state)
+        # Reversal.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_backwards(self.app_label, editor, new_state, project_state)
+        self.assertEqual(len(captured_queries), 1)
+        self.assertIn('CREATE COLLATION', captured_queries[0]['sql'])
+        # Deconstruction.
+        name, args, kwargs = operation.deconstruct()
+        self.assertEqual(name, 'RemoveCollation')
+        self.assertEqual(args, [])
+        self.assertEqual(kwargs, {'name': 'C_test', 'locale': 'C'})