Procházet zdrojové kódy

Fixed #32046 -- Added CreateCollation/RemoveCollation operations for PostgreSQL.

Tom Carrick před 4 roky
rodič
revize
f5e07601b2

+ 97 - 0
django/contrib/postgres/operations.py

@@ -164,3 +164,100 @@ class RemoveIndexConcurrently(NotInTransactionMixin, RemoveIndex):
             to_model_state = to_state.models[app_label, self.model_name_lower]
             index = to_model_state.get_index_by_name(self.name)
             schema_editor.add_index(model, index, concurrently=True)
+
+
+class CollationOperation(Operation):
+    def __init__(self, name, locale, *, provider='libc', deterministic=True):
+        self.name = name
+        self.locale = locale
+        self.provider = provider
+        self.deterministic = deterministic
+
+    def state_forwards(self, app_label, state):
+        pass
+
+    def deconstruct(self):
+        kwargs = {'name': self.name, 'locale': self.locale}
+        if self.provider and self.provider != 'libc':
+            kwargs['provider'] = self.provider
+        if self.deterministic is False:
+            kwargs['deterministic'] = self.deterministic
+        return (
+            self.__class__.__qualname__,
+            [],
+            kwargs,
+        )
+
+    def create_collation(self, schema_editor):
+        if (
+            self.deterministic is False and
+            not schema_editor.connection.features.supports_non_deterministic_collations
+        ):
+            raise NotSupportedError(
+                'Non-deterministic collations require PostgreSQL 12+.'
+            )
+        if (
+            self.provider != 'libc' and
+            not schema_editor.connection.features.supports_alternate_collation_providers
+        ):
+            raise NotSupportedError('Non-libc providers require PostgreSQL 10+.')
+        args = {'locale': schema_editor.quote_name(self.locale)}
+        if self.provider != 'libc':
+            args['provider'] = schema_editor.quote_name(self.provider)
+        if self.deterministic is False:
+            args['deterministic'] = 'false'
+        schema_editor.execute('CREATE COLLATION %(name)s (%(args)s)' % {
+            'name': schema_editor.quote_name(self.name),
+            'args': ', '.join(f'{option}={value}' for option, value in args.items()),
+        })
+
+    def remove_collation(self, schema_editor):
+        schema_editor.execute(
+            'DROP COLLATION %s' % schema_editor.quote_name(self.name),
+        )
+
+
+class CreateCollation(CollationOperation):
+    """Create a collation."""
+    def database_forwards(self, app_label, schema_editor, from_state, to_state):
+        if (
+            schema_editor.connection.vendor != 'postgresql' or
+            not router.allow_migrate(schema_editor.connection.alias, app_label)
+        ):
+            return
+        self.create_collation(schema_editor)
+
+    def database_backwards(self, app_label, schema_editor, from_state, to_state):
+        if not router.allow_migrate(schema_editor.connection.alias, app_label):
+            return
+        self.remove_collation(schema_editor)
+
+    def describe(self):
+        return f'Create collation {self.name}'
+
+    @property
+    def migration_name_fragment(self):
+        return 'create_collation_%s' % self.name.lower()
+
+
+class RemoveCollation(CollationOperation):
+    """Remove a collation."""
+    def database_forwards(self, app_label, schema_editor, from_state, to_state):
+        if (
+            schema_editor.connection.vendor != 'postgresql' or
+            not router.allow_migrate(schema_editor.connection.alias, app_label)
+        ):
+            return
+        self.remove_collation(schema_editor)
+
+    def database_backwards(self, app_label, schema_editor, from_state, to_state):
+        if not router.allow_migrate(schema_editor.connection.alias, app_label):
+            return
+        self.create_collation(schema_editor)
+
+    def describe(self):
+        return f'Remove collation {self.name}'
+
+    @property
+    def migration_name_fragment(self):
+        return 'remove_collation_%s' % self.name.lower()

+ 1 - 0
django/db/backends/postgresql/features.py

@@ -100,3 +100,4 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     supports_covering_indexes = property(operator.attrgetter('is_postgresql_11'))
     supports_covering_gist_indexes = property(operator.attrgetter('is_postgresql_12'))
     supports_non_deterministic_collations = property(operator.attrgetter('is_postgresql_12'))
+    supports_alternate_collation_providers = property(operator.attrgetter('is_postgresql_10'))

+ 10 - 0
docs/ref/contrib/postgres/fields.txt

@@ -285,6 +285,16 @@ transform do not change. For example::
     .. _citext: https://www.postgresql.org/docs/current/citext.html
     .. _the performance considerations: https://www.postgresql.org/docs/current/citext.html#id-1.11.7.17.7
 
+.. admonition:: Case-insensitive collations
+
+    On PostgreSQL 12+, it's preferable to use non-deterministic collations
+    instead of the ``citext`` extension. You can create them using the
+    :class:`~django.contrib.postgres.operations.CreateCollation` migration
+    operation. For more details, see :ref:`manage-postgresql-collations` and
+    the PostgreSQL documentation about `non-deterministic collations`_.
+
+    .. _non-deterministic collations: https://www.postgresql.org/docs/current/collation.html#COLLATION-NONDETERMINISTIC
+
 ``HStoreField``
 ===============
 

+ 50 - 0
docs/ref/contrib/postgres/operations.txt

@@ -115,6 +115,56 @@ them. In that case, connect to your Django database and run the query
 
     Installs the ``unaccent`` extension.
 
+.. _manage-postgresql-collations:
+
+Managing collations using migrations
+====================================
+
+.. versionadded:: 3.2
+
+If you need to filter or order a column using a particular collation that your
+operating system provides but PostgreSQL does not, you can manage collations in
+your database using a migration file. These collations can then be used with
+the ``db_collation`` parameter on :class:`~django.db.models.CharField`,
+:class:`~django.db.models.TextField`, and their subclasses.
+
+For example, to create a collation for German phone book ordering::
+
+    from django.contrib.postgres.operations import CreateCollation
+
+    class Migration(migrations.Migration):
+        ...
+
+        operations = [
+            CreateCollation(
+                'german_phonebook',
+                provider='icu',
+                locale='und-u-ks-level2',
+            ),
+            ...
+        ]
+
+.. class:: CreateCollation(name, locale, *, provider='libc', deterministic=True)
+
+    Creates a collation with the given ``name``, ``locale`` and ``provider``.
+
+    Set the ``deterministic`` parameter to ``False`` to create a
+    non-deterministic collation, such as for case-insensitive filtering.
+
+.. class:: RemoveCollation(name, locale, *, provider='libc', deterministic=True)
+
+    Removes the collations named ``name``.
+
+    When reversed this is creating a collation with the provided ``locale``,
+    ``provider``, and ``deterministic`` arguments. Therefore, ``locale`` is
+    required to make this operation reversible.
+
+.. admonition:: Restrictions
+
+    PostgreSQL 9.6 only supports the ``'libc'`` provider.
+
+    Non-deterministic collations are supported only on PostgreSQL 12+.
+
 Concurrent index operations
 ===========================
 

+ 5 - 0
docs/releases/3.2.txt

@@ -135,6 +135,11 @@ Minor features
   now checks that the extension already exists in the database and skips the
   migration if so.
 
+* The new :class:`~django.contrib.postgres.operations.CreateCollation` and
+  :class:`~django.contrib.postgres.operations.RemoveCollation` operations
+  allow creating and dropping collations on PostgreSQL. See
+  :ref:`manage-postgresql-collations` for more details.
+
 :mod:`django.contrib.redirects`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

+ 203 - 5
tests/postgres_tests/test_operations.py

@@ -1,11 +1,13 @@
 import unittest
+from unittest import mock
 
 from migrations.test_base import OperationTestBase
 
 from django.db import NotSupportedError, connection
 from django.db.migrations.state import ProjectState
 from django.db.models import Index
-from django.test import modify_settings, override_settings
+from django.db.utils import ProgrammingError
+from django.test import modify_settings, override_settings, skipUnlessDBFeature
 from django.test.utils import CaptureQueriesContext
 
 from . import PostgreSQLTestCase
@@ -13,8 +15,8 @@ from . import PostgreSQLTestCase
 try:
     from django.contrib.postgres.indexes import BrinIndex, BTreeIndex
     from django.contrib.postgres.operations import (
-        AddIndexConcurrently, BloomExtension, CreateExtension,
-        RemoveIndexConcurrently,
+        AddIndexConcurrently, BloomExtension, CreateCollation, CreateExtension,
+        RemoveCollation, RemoveIndexConcurrently,
     )
 except ImportError:
     pass
@@ -148,7 +150,7 @@ class RemoveIndexConcurrentlyTests(OperationTestBase):
         self.assertEqual(kwargs, {'model_name': 'Pony', 'name': 'pony_pink_idx'})
 
 
-class NoExtensionRouter():
+class NoMigrationRouter():
     def allow_migrate(self, db, app_label, **hints):
         return False
 
@@ -157,7 +159,7 @@ class NoExtensionRouter():
 class CreateExtensionTests(PostgreSQLTestCase):
     app_label = 'test_allow_create_extention'
 
-    @override_settings(DATABASE_ROUTERS=[NoExtensionRouter()])
+    @override_settings(DATABASE_ROUTERS=[NoMigrationRouter()])
     def test_no_allow_migrate(self):
         operation = CreateExtension('tablefunc')
         project_state = ProjectState()
@@ -213,3 +215,199 @@ class CreateExtensionTests(PostgreSQLTestCase):
                 operation.database_backwards(self.app_label, editor, project_state, new_state)
         self.assertEqual(len(captured_queries), 1)
         self.assertIn('SELECT', captured_queries[0]['sql'])
+
+
+@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
+class CreateCollationTests(PostgreSQLTestCase):
+    app_label = 'test_allow_create_collation'
+
+    @override_settings(DATABASE_ROUTERS=[NoMigrationRouter()])
+    def test_no_allow_migrate(self):
+        operation = CreateCollation('C_test', locale='C')
+        project_state = ProjectState()
+        new_state = project_state.clone()
+        # Don't create a collation.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_forwards(self.app_label, editor, project_state, new_state)
+        self.assertEqual(len(captured_queries), 0)
+        # Reversal.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_backwards(self.app_label, editor, new_state, project_state)
+        self.assertEqual(len(captured_queries), 0)
+
+    def test_create(self):
+        operation = CreateCollation('C_test', locale='C')
+        self.assertEqual(operation.migration_name_fragment, 'create_collation_c_test')
+        self.assertEqual(operation.describe(), 'Create collation C_test')
+        project_state = ProjectState()
+        new_state = project_state.clone()
+        # Create a collation.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_forwards(self.app_label, editor, project_state, new_state)
+        self.assertEqual(len(captured_queries), 1)
+        self.assertIn('CREATE COLLATION', captured_queries[0]['sql'])
+        # Creating the same collation raises an exception.
+        with self.assertRaisesMessage(ProgrammingError, 'already exists'):
+            with connection.schema_editor(atomic=True) as editor:
+                operation.database_forwards(self.app_label, editor, project_state, new_state)
+        # Reversal.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_backwards(self.app_label, editor, new_state, project_state)
+        self.assertEqual(len(captured_queries), 1)
+        self.assertIn('DROP COLLATION', captured_queries[0]['sql'])
+        # Deconstruction.
+        name, args, kwargs = operation.deconstruct()
+        self.assertEqual(name, 'CreateCollation')
+        self.assertEqual(args, [])
+        self.assertEqual(kwargs, {'name': 'C_test', 'locale': 'C'})
+
+    @skipUnlessDBFeature('supports_non_deterministic_collations')
+    def test_create_non_deterministic_collation(self):
+        operation = CreateCollation(
+            'case_insensitive_test',
+            'und-u-ks-level2',
+            provider='icu',
+            deterministic=False,
+        )
+        project_state = ProjectState()
+        new_state = project_state.clone()
+        # Create a collation.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_forwards(self.app_label, editor, project_state, new_state)
+        self.assertEqual(len(captured_queries), 1)
+        self.assertIn('CREATE COLLATION', captured_queries[0]['sql'])
+        # Reversal.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_backwards(self.app_label, editor, new_state, project_state)
+        self.assertEqual(len(captured_queries), 1)
+        self.assertIn('DROP COLLATION', captured_queries[0]['sql'])
+        # Deconstruction.
+        name, args, kwargs = operation.deconstruct()
+        self.assertEqual(name, 'CreateCollation')
+        self.assertEqual(args, [])
+        self.assertEqual(kwargs, {
+            'name': 'case_insensitive_test',
+            'locale': 'und-u-ks-level2',
+            'provider': 'icu',
+            'deterministic': False,
+        })
+
+    @skipUnlessDBFeature('supports_alternate_collation_providers')
+    def test_create_collation_alternate_provider(self):
+        operation = CreateCollation(
+            'german_phonebook_test',
+            provider='icu',
+            locale='de-u-co-phonebk',
+        )
+        project_state = ProjectState()
+        new_state = project_state.clone()
+        # Create an collation.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_forwards(self.app_label, editor, project_state, new_state)
+        self.assertEqual(len(captured_queries), 1)
+        self.assertIn('CREATE COLLATION', captured_queries[0]['sql'])
+        # Reversal.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_backwards(self.app_label, editor, new_state, project_state)
+        self.assertEqual(len(captured_queries), 1)
+        self.assertIn('DROP COLLATION', captured_queries[0]['sql'])
+
+    def test_nondeterministic_collation_not_supported(self):
+        operation = CreateCollation(
+            'case_insensitive_test',
+            provider='icu',
+            locale='und-u-ks-level2',
+            deterministic=False,
+        )
+        project_state = ProjectState()
+        new_state = project_state.clone()
+        msg = 'Non-deterministic collations require PostgreSQL 12+.'
+        with connection.schema_editor(atomic=False) as editor:
+            with mock.patch(
+                'django.db.backends.postgresql.features.DatabaseFeatures.'
+                'supports_non_deterministic_collations',
+                False,
+            ):
+                with self.assertRaisesMessage(NotSupportedError, msg):
+                    operation.database_forwards(self.app_label, editor, project_state, new_state)
+
+    def test_collation_with_icu_provider_raises_error(self):
+        operation = CreateCollation(
+            'german_phonebook',
+            provider='icu',
+            locale='de-u-co-phonebk',
+        )
+        project_state = ProjectState()
+        new_state = project_state.clone()
+        msg = 'Non-libc providers require PostgreSQL 10+.'
+        with connection.schema_editor(atomic=False) as editor:
+            with mock.patch(
+                'django.db.backends.postgresql.features.DatabaseFeatures.'
+                'supports_alternate_collation_providers',
+                False,
+            ):
+                with self.assertRaisesMessage(NotSupportedError, msg):
+                    operation.database_forwards(self.app_label, editor, project_state, new_state)
+
+
+@unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific tests.')
+class RemoveCollationTests(PostgreSQLTestCase):
+    app_label = 'test_allow_remove_collation'
+
+    @override_settings(DATABASE_ROUTERS=[NoMigrationRouter()])
+    def test_no_allow_migrate(self):
+        operation = RemoveCollation('C_test', locale='C')
+        project_state = ProjectState()
+        new_state = project_state.clone()
+        # Don't create a collation.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_forwards(self.app_label, editor, project_state, new_state)
+        self.assertEqual(len(captured_queries), 0)
+        # Reversal.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_backwards(self.app_label, editor, new_state, project_state)
+        self.assertEqual(len(captured_queries), 0)
+
+    def test_remove(self):
+        operation = CreateCollation('C_test', locale='C')
+        project_state = ProjectState()
+        new_state = project_state.clone()
+        with connection.schema_editor(atomic=False) as editor:
+            operation.database_forwards(self.app_label, editor, project_state, new_state)
+
+        operation = RemoveCollation('C_test', locale='C')
+        self.assertEqual(operation.migration_name_fragment, 'remove_collation_c_test')
+        self.assertEqual(operation.describe(), 'Remove collation C_test')
+        project_state = ProjectState()
+        new_state = project_state.clone()
+        # Remove a collation.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_forwards(self.app_label, editor, project_state, new_state)
+        self.assertEqual(len(captured_queries), 1)
+        self.assertIn('DROP COLLATION', captured_queries[0]['sql'])
+        # Removing a nonexistent collation raises an exception.
+        with self.assertRaisesMessage(ProgrammingError, 'does not exist'):
+            with connection.schema_editor(atomic=True) as editor:
+                operation.database_forwards(self.app_label, editor, project_state, new_state)
+        # Reversal.
+        with CaptureQueriesContext(connection) as captured_queries:
+            with connection.schema_editor(atomic=False) as editor:
+                operation.database_backwards(self.app_label, editor, new_state, project_state)
+        self.assertEqual(len(captured_queries), 1)
+        self.assertIn('CREATE COLLATION', captured_queries[0]['sql'])
+        # Deconstruction.
+        name, args, kwargs = operation.deconstruct()
+        self.assertEqual(name, 'RemoveCollation')
+        self.assertEqual(args, [])
+        self.assertEqual(kwargs, {'name': 'C_test', 'locale': 'C'})