Browse Source

Fixed #22583 -- Allowed RunPython and RunSQL to provide hints to the db router.

Thanks Markus Holtermann and Tim Graham for the review.
Loic Bistuer 10 năm trước cách đây
mục cha
commit
8f4877c89d

+ 7 - 9
django/db/migrations/operations/base.py

@@ -98,17 +98,15 @@ class Operation(object):
         """
         return self.references_model(model_name, app_label)
 
-    def allowed_to_migrate(self, connection_alias, model):
+    def allowed_to_migrate(self, connection_alias, model, hints=None):
         """
-        Returns if we're allowed to migrate the model. Checks the router,
-        if it's a proxy, if it's managed, and if it's swapped out.
+        Returns if we're allowed to migrate the model.
         """
-        return (
-            not model._meta.proxy and
-            not model._meta.swapped and
-            model._meta.managed and
-            router.allow_migrate(connection_alias, model)
-        )
+        # Always skip if proxy, swapped out, or unmanaged.
+        if model and (model._meta.proxy or model._meta.swapped or not model._meta.managed):
+            return False
+
+        return router.allow_migrate(connection_alias, model, **(hints or {}))
 
     def __repr__(self):
         return "<%s %s%s>" % (

+ 20 - 10
django/db/migrations/operations/special.py

@@ -63,10 +63,11 @@ class RunSQL(Operation):
     """
     noop = ''
 
-    def __init__(self, sql, reverse_sql=None, state_operations=None):
+    def __init__(self, sql, reverse_sql=None, state_operations=None, hints=None):
         self.sql = sql
         self.reverse_sql = reverse_sql
         self.state_operations = state_operations or []
+        self.hints = hints or {}
 
     def deconstruct(self):
         kwargs = {
@@ -76,6 +77,8 @@ class RunSQL(Operation):
             kwargs['reverse_sql'] = self.reverse_sql
         if self.state_operations:
             kwargs['state_operations'] = self.state_operations
+        if self.hints:
+            kwargs['hints'] = self.hints
         return (
             self.__class__.__name__,
             [],
@@ -91,12 +94,14 @@ class RunSQL(Operation):
             state_operation.state_forwards(app_label, state)
 
     def database_forwards(self, app_label, schema_editor, from_state, to_state):
-        self._run_sql(schema_editor, self.sql)
+        if self.allowed_to_migrate(schema_editor.connection.alias, None, hints=self.hints):
+            self._run_sql(schema_editor, self.sql)
 
     def database_backwards(self, app_label, schema_editor, from_state, to_state):
         if self.reverse_sql is None:
             raise NotImplementedError("You cannot reverse this operation")
-        self._run_sql(schema_editor, self.reverse_sql)
+        if self.allowed_to_migrate(schema_editor.connection.alias, None, hints=self.hints):
+            self._run_sql(schema_editor, self.reverse_sql)
 
     def describe(self):
         return "Raw SQL operation"
@@ -125,7 +130,7 @@ class RunPython(Operation):
 
     reduces_to_sql = False
 
-    def __init__(self, code, reverse_code=None, atomic=True):
+    def __init__(self, code, reverse_code=None, atomic=True, hints=None):
         self.atomic = atomic
         # Forwards code
         if not callable(code):
@@ -138,6 +143,7 @@ class RunPython(Operation):
             if not callable(reverse_code):
                 raise ValueError("RunPython must be supplied with callable arguments")
             self.reverse_code = reverse_code
+        self.hints = hints or {}
 
     def deconstruct(self):
         kwargs = {
@@ -147,6 +153,8 @@ class RunPython(Operation):
             kwargs['reverse_code'] = self.reverse_code
         if self.atomic is not True:
             kwargs['atomic'] = self.atomic
+        if self.hints:
+            kwargs['hints'] = self.hints
         return (
             self.__class__.__name__,
             [],
@@ -163,16 +171,18 @@ class RunPython(Operation):
         pass
 
     def database_forwards(self, app_label, schema_editor, from_state, to_state):
-        # We now execute the Python code in a context that contains a 'models'
-        # object, representing the versioned models as an app registry.
-        # We could try to override the global cache, but then people will still
-        # use direct imports, so we go with a documentation approach instead.
-        self.code(from_state.apps, schema_editor)
+        if self.allowed_to_migrate(schema_editor.connection.alias, None, hints=self.hints):
+            # We now execute the Python code in a context that contains a 'models'
+            # object, representing the versioned models as an app registry.
+            # We could try to override the global cache, but then people will still
+            # use direct imports, so we go with a documentation approach instead.
+            self.code(from_state.apps, schema_editor)
 
     def database_backwards(self, app_label, schema_editor, from_state, to_state):
         if self.reverse_code is None:
             raise NotImplementedError("You cannot reverse this operation")
-        self.reverse_code(from_state.apps, schema_editor)
+        if self.allowed_to_migrate(schema_editor.connection.alias, None, hints=self.hints):
+            self.reverse_code(from_state.apps, schema_editor)
 
     def describe(self):
         return "Raw Python operation"

+ 2 - 2
django/db/utils.py

@@ -316,7 +316,7 @@ class ConnectionRouter(object):
                     return allow
         return obj1._state.db == obj2._state.db
 
-    def allow_migrate(self, db, model):
+    def allow_migrate(self, db, model, **hints):
         for router in self.routers:
             try:
                 try:
@@ -331,7 +331,7 @@ class ConnectionRouter(object):
                 # If the router doesn't have a method, skip to the next one.
                 pass
             else:
-                allow = method(db, model)
+                allow = method(db, model, **hints)
                 if allow is not None:
                     return allow
         return True

+ 18 - 2
docs/ref/migration-operations.txt

@@ -206,7 +206,7 @@ Special Operations
 RunSQL
 ------
 
-.. class:: RunSQL(sql, reverse_sql=None, state_operations=None)
+.. class:: RunSQL(sql, reverse_sql=None, state_operations=None, hints=None)
 
 Allows running of arbitrary SQL on the database - useful for more advanced
 features of database backends that Django doesn't support directly, like
@@ -235,6 +235,11 @@ operation here so that the autodetector still has an up-to-date state of the
 model (otherwise, when you next run ``makemigrations``, it won't see any
 operation that adds that field and so will try to run it again).
 
+The optional ``hints`` argument will be passed as ``**hints`` to the
+:meth:`allow_migrate` method of database routers to assist them in making
+routing decisions. See :ref:`topics-db-multi-db-hints` for more details on
+database hints.
+
 .. versionchanged:: 1.7.1
 
     If you want to include literal percent signs in a query without parameters
@@ -245,6 +250,8 @@ operation that adds that field and so will try to run it again).
     The ability to pass parameters to the ``sql`` and ``reverse_sql`` queries
     was added.
 
+    The ``hints`` argument was added.
+
 .. attribute:: RunSQL.noop
 
     .. versionadded:: 1.8
@@ -258,7 +265,7 @@ operation that adds that field and so will try to run it again).
 RunPython
 ---------
 
-.. class:: RunPython(code, reverse_code=None, atomic=True)
+.. class:: RunPython(code, reverse_code=None, atomic=True, hints=None)
 
 Runs custom Python code in a historical context. ``code`` (and ``reverse_code``
 if supplied) should be callable objects that accept two arguments; the first is
@@ -267,6 +274,15 @@ match the operation's place in the project history, and the second is an
 instance of :class:`SchemaEditor
 <django.db.backends.schema.BaseDatabaseSchemaEditor>`.
 
+The optional ``hints`` argument will be passed as ``**hints`` to the
+:meth:`allow_migrate` method of database routers to assist them in making a
+routing decision. See :ref:`topics-db-multi-db-hints` for more details on
+database hints.
+
+.. versionadded:: 1.8
+
+    The ``hints`` argument was added.
+
 You are advised to write the code as a separate function above the ``Migration``
 class in the migration file, and just pass it to ``RunPython``. Here's an
 example of using ``RunPython`` to create some initial objects on a ``Country``

+ 14 - 0
docs/releases/1.8.txt

@@ -462,6 +462,12 @@ Migrations
   attribute/method were added to ease in making ``RunPython`` and ``RunSQL``
   operations reversible.
 
+* The :class:`~django.db.migrations.operations.RunPython` and
+  :class:`~django.db.migrations.operations.RunSQL` operations now accept a
+  ``hints`` parameter that will be passed to :meth:`allow_migrate`. To take
+  advantage of this feature you must ensure that the ``allow_migrate()`` method
+  of all your routers accept ``**hints``.
+
 Models
 ^^^^^^
 
@@ -1029,6 +1035,14 @@ Miscellaneous
 * :func:`django.utils.translation.get_language()` now returns ``None`` instead
   of :setting:`LANGUAGE_CODE` when translations are temporarily deactivated.
 
+* The migration operations :class:`~django.db.migrations.operations.RunPython`
+  and :class:`~django.db.migrations.operations.RunSQL` now call the
+  :meth:`allow_migrate` method of database routers. In these cases the
+  ``model`` argument of ``allow_migrate()`` is set to ``None``, so the router
+  must properly handle this value. This is most useful when used together with
+  the newly introduced ``hints`` parameter for these operations, but it can
+  also be used to disable migrations from running on a particular database.
+
 .. _deprecated-features-1.8:
 
 Features deprecated in 1.8

+ 3 - 3
docs/topics/db/multi-db.txt

@@ -150,7 +150,7 @@ A database Router is a class that provides up to four methods:
     used by foreign key and many to many operations to determine if a
     relation should be allowed between two objects.
 
-.. method:: allow_migrate(db, model)
+.. method:: allow_migrate(db, model, **hints)
 
     Determine if the ``model`` should have tables/indexes created in the
     database with alias ``db``. Return True if the model should be
@@ -293,7 +293,7 @@ send queries for the ``auth`` app to ``auth_db``::
                return True
             return None
 
-        def allow_migrate(self, db, model):
+        def allow_migrate(self, db, model, **hints):
             """
             Make sure the auth app only appears in the 'auth_db'
             database.
@@ -333,7 +333,7 @@ from::
                 return True
             return None
 
-        def allow_migrate(self, db, model):
+        def allow_migrate(self, db, model, **hints):
             """
             All non-auth models end up in this pool.
             """

+ 9 - 11
docs/topics/migrations.txt

@@ -545,28 +545,26 @@ attribute::
             migrations.RunPython(forwards),
         ]
 
-You can also use your database router's ``allow_migrate()`` method, but keep in
-mind that the imported router needs to stay around as long as it is referenced
-inside a migration:
+.. versionadded:: 1.8
+
+You can also provide hints that will be passed to the :meth:`allow_migrate()`
+method of database routers as ``**hints``:
 
 .. snippet::
     :filename: myapp/dbrouters.py
 
     class MyRouter(object):
 
-        def allow_migrate(self, db, model):
-            return db == 'default'
+        def allow_migrate(self, db, model, **hints):
+            if 'target_db' in hints:
+                return db == hints['target_db']
+            return True
 
 Then, to leverage this in your migrations, do the following::
 
     from django.db import migrations
 
-    from myappname.dbrouters import MyRouter
-
     def forwards(apps, schema_editor):
-        MyModel = apps.get_model("myappname", "MyModel")
-        if not MyRouter().allow_migrate(schema_editor.connection.alias, MyModel):
-            return
         # Your migration code goes here
 
     class Migration(migrations.Migration):
@@ -576,7 +574,7 @@ Then, to leverage this in your migrations, do the following::
         ]
 
         operations = [
-            migrations.RunPython(forwards),
+            migrations.RunPython(forwards, hints={'target_db': 'default'}),
         ]
 
 More advanced migrations

+ 174 - 0
tests/migrations/test_multidb.py

@@ -0,0 +1,174 @@
+import unittest
+
+try:
+    import sqlparse
+except ImportError:
+    sqlparse = None
+
+from django.db import migrations, models, connection
+from django.db.migrations.state import ProjectState
+from django.test import override_settings
+
+from .test_operations import OperationTestBase
+
+
+class AgnosticRouter(object):
+    """
+    A router that doesn't have an opinion regarding migrating.
+    """
+    def allow_migrate(self, db, model, **hints):
+        return None
+
+
+class MigrateNothingRouter(object):
+    """
+    A router that doesn't allow migrating.
+    """
+    def allow_migrate(self, db, model, **hints):
+        return False
+
+
+class MigrateEverythingRouter(object):
+    """
+    A router that always allows migrating.
+    """
+    def allow_migrate(self, db, model, **hints):
+        return True
+
+
+class MigrateWhenFooRouter(object):
+    """
+    A router that allows migrating depending on a hint.
+    """
+    def allow_migrate(self, db, model, **hints):
+        return hints.get('foo', False)
+
+
+class MultiDBOperationTests(OperationTestBase):
+    multi_db = True
+
+    def _test_create_model(self, app_label, should_run):
+        """
+        Tests that CreateModel honours multi-db settings.
+        """
+        operation = migrations.CreateModel(
+            "Pony",
+            [("id", models.AutoField(primary_key=True))],
+        )
+        # Test the state alteration
+        project_state = ProjectState()
+        new_state = project_state.clone()
+        operation.state_forwards(app_label, new_state)
+        # Test the database alteration
+        self.assertTableNotExists("%s_pony" % app_label)
+        with connection.schema_editor() as editor:
+            operation.database_forwards(app_label, editor, project_state, new_state)
+        if should_run:
+            self.assertTableExists("%s_pony" % app_label)
+        else:
+            self.assertTableNotExists("%s_pony" % app_label)
+        # And test reversal
+        with connection.schema_editor() as editor:
+            operation.database_backwards(app_label, editor, new_state, project_state)
+        self.assertTableNotExists("%s_pony" % app_label)
+
+    @override_settings(DATABASE_ROUTERS=[AgnosticRouter()])
+    def test_create_model(self):
+        """
+        Test when router doesn't have an opinion (i.e. CreateModel should run).
+        """
+        self._test_create_model("test_mltdb_crmo", should_run=True)
+
+    @override_settings(DATABASE_ROUTERS=[MigrateNothingRouter()])
+    def test_create_model2(self):
+        """
+        Test when router returns False (i.e. CreateModel shouldn't run).
+        """
+        self._test_create_model("test_mltdb_crmo2", should_run=False)
+
+    @override_settings(DATABASE_ROUTERS=[MigrateEverythingRouter()])
+    def test_create_model3(self):
+        """
+        Test when router returns True (i.e. CreateModel should run).
+        """
+        self._test_create_model("test_mltdb_crmo3", should_run=True)
+
+    def test_create_model4(self):
+        """
+        Test multiple routers.
+        """
+        with override_settings(DATABASE_ROUTERS=[AgnosticRouter(), AgnosticRouter()]):
+            self._test_create_model("test_mltdb_crmo4", should_run=True)
+        with override_settings(DATABASE_ROUTERS=[MigrateNothingRouter(), MigrateEverythingRouter()]):
+            self._test_create_model("test_mltdb_crmo4", should_run=False)
+        with override_settings(DATABASE_ROUTERS=[MigrateEverythingRouter(), MigrateNothingRouter()]):
+            self._test_create_model("test_mltdb_crmo4", should_run=True)
+
+    def _test_run_sql(self, app_label, should_run, hints=None):
+        with override_settings(DATABASE_ROUTERS=[MigrateEverythingRouter()]):
+            project_state = self.set_up_test_model(app_label)
+
+        sql = """
+        INSERT INTO {0}_pony (pink, weight) VALUES (1, 3.55);
+        INSERT INTO {0}_pony (pink, weight) VALUES (3, 5.0);
+        """.format(app_label)
+
+        operation = migrations.RunSQL(sql, hints=hints or {})
+        # Test the state alteration does nothing
+        new_state = project_state.clone()
+        operation.state_forwards(app_label, new_state)
+        self.assertEqual(new_state, project_state)
+        # Test the database alteration
+        self.assertEqual(project_state.apps.get_model(app_label, "Pony").objects.count(), 0)
+        with connection.schema_editor() as editor:
+            operation.database_forwards(app_label, editor, project_state, new_state)
+        Pony = project_state.apps.get_model(app_label, "Pony")
+        if should_run:
+            self.assertEqual(Pony.objects.count(), 2)
+        else:
+            self.assertEqual(Pony.objects.count(), 0)
+
+    @unittest.skipIf(sqlparse is None and connection.features.requires_sqlparse_for_splitting, "Missing sqlparse")
+    @override_settings(DATABASE_ROUTERS=[MigrateNothingRouter()])
+    def test_run_sql(self):
+        self._test_run_sql("test_mltdb_runsql", should_run=False)
+
+    @unittest.skipIf(sqlparse is None and connection.features.requires_sqlparse_for_splitting, "Missing sqlparse")
+    @override_settings(DATABASE_ROUTERS=[MigrateWhenFooRouter()])
+    def test_run_sql2(self):
+        self._test_run_sql("test_mltdb_runsql2", should_run=False)
+        self._test_run_sql("test_mltdb_runsql2", should_run=True, hints={'foo': True})
+
+    def _test_run_python(self, app_label, should_run, hints=None):
+        with override_settings(DATABASE_ROUTERS=[MigrateEverythingRouter()]):
+            project_state = self.set_up_test_model(app_label)
+
+        # Create the operation
+        def inner_method(models, schema_editor):
+            Pony = models.get_model(app_label, "Pony")
+            Pony.objects.create(pink=1, weight=3.55)
+            Pony.objects.create(weight=5)
+
+        operation = migrations.RunPython(inner_method, hints=hints or {})
+        # Test the state alteration does nothing
+        new_state = project_state.clone()
+        operation.state_forwards(app_label, new_state)
+        self.assertEqual(new_state, project_state)
+        # Test the database alteration
+        self.assertEqual(project_state.apps.get_model(app_label, "Pony").objects.count(), 0)
+        with connection.schema_editor() as editor:
+            operation.database_forwards(app_label, editor, project_state, new_state)
+        Pony = project_state.apps.get_model(app_label, "Pony")
+        if should_run:
+            self.assertEqual(Pony.objects.count(), 2)
+        else:
+            self.assertEqual(Pony.objects.count(), 0)
+
+    @override_settings(DATABASE_ROUTERS=[MigrateNothingRouter()])
+    def test_run_python(self):
+        self._test_run_python("test_mltdb_runpython", should_run=False)
+
+    @override_settings(DATABASE_ROUTERS=[MigrateWhenFooRouter()])
+    def test_run_python2(self):
+        self._test_run_python("test_mltdb_runpython2", should_run=False)
+        self._test_run_python("test_mltdb_runpython2", should_run=True, hints={'foo': True})

+ 0 - 38
tests/migrations/test_operations.py

@@ -1679,44 +1679,6 @@ class OperationTests(OperationTestBase):
         self.assertEqual(sorted(definition[2]), ["database_operations", "state_operations"])
 
 
-class MigrateNothingRouter(object):
-    """
-    A router that doesn't allow storing any model in any database.
-    """
-    def allow_migrate(self, db, model):
-        return False
-
-
-@override_settings(DATABASE_ROUTERS=[MigrateNothingRouter()])
-class MultiDBOperationTests(MigrationTestBase):
-    multi_db = True
-
-    def test_create_model(self):
-        """
-        Tests that CreateModel honours multi-db settings.
-        """
-        operation = migrations.CreateModel(
-            "Pony",
-            [
-                ("id", models.AutoField(primary_key=True)),
-                ("pink", models.IntegerField(default=1)),
-            ],
-        )
-        # Test the state alteration
-        project_state = ProjectState()
-        new_state = project_state.clone()
-        operation.state_forwards("test_crmo", new_state)
-        # Test the database alteration
-        self.assertTableNotExists("test_crmo_pony")
-        with connection.schema_editor() as editor:
-            operation.database_forwards("test_crmo", editor, project_state, new_state)
-        self.assertTableNotExists("test_crmo_pony")
-        # And test reversal
-        with connection.schema_editor() as editor:
-            operation.database_backwards("test_crmo", editor, new_state, project_state)
-        self.assertTableNotExists("test_crmo_pony")
-
-
 class SwappableOperationTests(OperationTestBase):
     """
     Tests that key operations ignore swappable models