Browse Source

Fixed #22583 -- Allowed RunPython and RunSQL to provide hints to the db router.

Thanks Markus Holtermann and Tim Graham for the review.
Loic Bistuer 10 years ago
parent
commit
8f4877c89d

+ 7 - 9
django/db/migrations/operations/base.py

@@ -98,17 +98,15 @@ class Operation(object):
         """
         return self.references_model(model_name, app_label)
 
-    def allowed_to_migrate(self, connection_alias, model):
+    def allowed_to_migrate(self, connection_alias, model, hints=None):
         """
-        Returns if we're allowed to migrate the model. Checks the router,
-        if it's a proxy, if it's managed, and if it's swapped out.
+        Returns if we're allowed to migrate the model.
         """
-        return (
-            not model._meta.proxy and
-            not model._meta.swapped and
-            model._meta.managed and
-            router.allow_migrate(connection_alias, model)
-        )
+        # Always skip if proxy, swapped out, or unmanaged.
+        if model and (model._meta.proxy or model._meta.swapped or not model._meta.managed):
+            return False
+
+        return router.allow_migrate(connection_alias, model, **(hints or {}))
 
     def __repr__(self):
         return "<%s %s%s>" % (

+ 20 - 10
django/db/migrations/operations/special.py

@@ -63,10 +63,11 @@ class RunSQL(Operation):
     """
     noop = ''
 
-    def __init__(self, sql, reverse_sql=None, state_operations=None):
+    def __init__(self, sql, reverse_sql=None, state_operations=None, hints=None):
         self.sql = sql
         self.reverse_sql = reverse_sql
         self.state_operations = state_operations or []
+        self.hints = hints or {}
 
     def deconstruct(self):
         kwargs = {
@@ -76,6 +77,8 @@ class RunSQL(Operation):
             kwargs['reverse_sql'] = self.reverse_sql
         if self.state_operations:
             kwargs['state_operations'] = self.state_operations
+        if self.hints:
+            kwargs['hints'] = self.hints
         return (
             self.__class__.__name__,
             [],
@@ -91,12 +94,14 @@ class RunSQL(Operation):
             state_operation.state_forwards(app_label, state)
 
     def database_forwards(self, app_label, schema_editor, from_state, to_state):
-        self._run_sql(schema_editor, self.sql)
+        if self.allowed_to_migrate(schema_editor.connection.alias, None, hints=self.hints):
+            self._run_sql(schema_editor, self.sql)
 
     def database_backwards(self, app_label, schema_editor, from_state, to_state):
         if self.reverse_sql is None:
             raise NotImplementedError("You cannot reverse this operation")
-        self._run_sql(schema_editor, self.reverse_sql)
+        if self.allowed_to_migrate(schema_editor.connection.alias, None, hints=self.hints):
+            self._run_sql(schema_editor, self.reverse_sql)
 
     def describe(self):
         return "Raw SQL operation"
@@ -125,7 +130,7 @@ class RunPython(Operation):
 
     reduces_to_sql = False
 
-    def __init__(self, code, reverse_code=None, atomic=True):
+    def __init__(self, code, reverse_code=None, atomic=True, hints=None):
         self.atomic = atomic
         # Forwards code
         if not callable(code):
@@ -138,6 +143,7 @@ class RunPython(Operation):
             if not callable(reverse_code):
                 raise ValueError("RunPython must be supplied with callable arguments")
             self.reverse_code = reverse_code
+        self.hints = hints or {}
 
     def deconstruct(self):
         kwargs = {
@@ -147,6 +153,8 @@ class RunPython(Operation):
             kwargs['reverse_code'] = self.reverse_code
         if self.atomic is not True:
             kwargs['atomic'] = self.atomic
+        if self.hints:
+            kwargs['hints'] = self.hints
         return (
             self.__class__.__name__,
             [],
@@ -163,16 +171,18 @@ class RunPython(Operation):
         pass
 
     def database_forwards(self, app_label, schema_editor, from_state, to_state):
-        # We now execute the Python code in a context that contains a 'models'
-        # object, representing the versioned models as an app registry.
-        # We could try to override the global cache, but then people will still
-        # use direct imports, so we go with a documentation approach instead.
-        self.code(from_state.apps, schema_editor)
+        if self.allowed_to_migrate(schema_editor.connection.alias, None, hints=self.hints):
+            # We now execute the Python code in a context that contains a 'models'
+            # object, representing the versioned models as an app registry.
+            # We could try to override the global cache, but then people will still
+            # use direct imports, so we go with a documentation approach instead.
+            self.code(from_state.apps, schema_editor)
 
     def database_backwards(self, app_label, schema_editor, from_state, to_state):
         if self.reverse_code is None:
             raise NotImplementedError("You cannot reverse this operation")
-        self.reverse_code(from_state.apps, schema_editor)
+        if self.allowed_to_migrate(schema_editor.connection.alias, None, hints=self.hints):
+            self.reverse_code(from_state.apps, schema_editor)
 
     def describe(self):
         return "Raw Python operation"

+ 2 - 2
django/db/utils.py

@@ -316,7 +316,7 @@ class ConnectionRouter(object):
                     return allow
         return obj1._state.db == obj2._state.db
 
-    def allow_migrate(self, db, model):
+    def allow_migrate(self, db, model, **hints):
         for router in self.routers:
             try:
                 try:
@@ -331,7 +331,7 @@ class ConnectionRouter(object):
                 # If the router doesn't have a method, skip to the next one.
                 pass
             else:
-                allow = method(db, model)
+                allow = method(db, model, **hints)
                 if allow is not None:
                     return allow
         return True

+ 18 - 2
docs/ref/migration-operations.txt

@@ -206,7 +206,7 @@ Special Operations
 RunSQL
 ------
 
-.. class:: RunSQL(sql, reverse_sql=None, state_operations=None)
+.. class:: RunSQL(sql, reverse_sql=None, state_operations=None, hints=None)
 
 Allows running of arbitrary SQL on the database - useful for more advanced
 features of database backends that Django doesn't support directly, like
@@ -235,6 +235,11 @@ operation here so that the autodetector still has an up-to-date state of the
 model (otherwise, when you next run ``makemigrations``, it won't see any
 operation that adds that field and so will try to run it again).
 
+The optional ``hints`` argument will be passed as ``**hints`` to the
+:meth:`allow_migrate` method of database routers to assist them in making
+routing decisions. See :ref:`topics-db-multi-db-hints` for more details on
+database hints.
+
 .. versionchanged:: 1.7.1
 
     If you want to include literal percent signs in a query without parameters
@@ -245,6 +250,8 @@ operation that adds that field and so will try to run it again).
     The ability to pass parameters to the ``sql`` and ``reverse_sql`` queries
     was added.
 
+    The ``hints`` argument was added.
+
 .. attribute:: RunSQL.noop
 
     .. versionadded:: 1.8
@@ -258,7 +265,7 @@ operation that adds that field and so will try to run it again).
 RunPython
 ---------
 
-.. class:: RunPython(code, reverse_code=None, atomic=True)
+.. class:: RunPython(code, reverse_code=None, atomic=True, hints=None)
 
 Runs custom Python code in a historical context. ``code`` (and ``reverse_code``
 if supplied) should be callable objects that accept two arguments; the first is
@@ -267,6 +274,15 @@ match the operation's place in the project history, and the second is an
 instance of :class:`SchemaEditor
 <django.db.backends.schema.BaseDatabaseSchemaEditor>`.
 
+The optional ``hints`` argument will be passed as ``**hints`` to the
+:meth:`allow_migrate` method of database routers to assist them in making a
+routing decision. See :ref:`topics-db-multi-db-hints` for more details on
+database hints.
+
+.. versionadded:: 1.8
+
+    The ``hints`` argument was added.
+
 You are advised to write the code as a separate function above the ``Migration``
 class in the migration file, and just pass it to ``RunPython``. Here's an
 example of using ``RunPython`` to create some initial objects on a ``Country``

+ 14 - 0
docs/releases/1.8.txt

@@ -462,6 +462,12 @@ Migrations
   attribute/method were added to ease in making ``RunPython`` and ``RunSQL``
   operations reversible.
 
+* The :class:`~django.db.migrations.operations.RunPython` and
+  :class:`~django.db.migrations.operations.RunSQL` operations now accept a
+  ``hints`` parameter that will be passed to :meth:`allow_migrate`. To take
+  advantage of this feature you must ensure that the ``allow_migrate()`` method
+  of all your routers accept ``**hints``.
+
 Models
 ^^^^^^
 
@@ -1029,6 +1035,14 @@ Miscellaneous
 * :func:`django.utils.translation.get_language()` now returns ``None`` instead
   of :setting:`LANGUAGE_CODE` when translations are temporarily deactivated.
 
+* The migration operations :class:`~django.db.migrations.operations.RunPython`
+  and :class:`~django.db.migrations.operations.RunSQL` now call the
+  :meth:`allow_migrate` method of database routers. In these cases the
+  ``model`` argument of ``allow_migrate()`` is set to ``None``, so the router
+  must properly handle this value. This is most useful when used together with
+  the newly introduced ``hints`` parameter for these operations, but it can
+  also be used to disable migrations from running on a particular database.
+
 .. _deprecated-features-1.8:
 
 Features deprecated in 1.8

+ 3 - 3
docs/topics/db/multi-db.txt

@@ -150,7 +150,7 @@ A database Router is a class that provides up to four methods:
     used by foreign key and many to many operations to determine if a
     relation should be allowed between two objects.
 
-.. method:: allow_migrate(db, model)
+.. method:: allow_migrate(db, model, **hints)
 
     Determine if the ``model`` should have tables/indexes created in the
     database with alias ``db``. Return True if the model should be
@@ -293,7 +293,7 @@ send queries for the ``auth`` app to ``auth_db``::
                return True
             return None
 
-        def allow_migrate(self, db, model):
+        def allow_migrate(self, db, model, **hints):
             """
             Make sure the auth app only appears in the 'auth_db'
             database.
@@ -333,7 +333,7 @@ from::
                 return True
             return None
 
-        def allow_migrate(self, db, model):
+        def allow_migrate(self, db, model, **hints):
             """
             All non-auth models end up in this pool.
             """

+ 9 - 11
docs/topics/migrations.txt

@@ -545,28 +545,26 @@ attribute::
             migrations.RunPython(forwards),
         ]
 
-You can also use your database router's ``allow_migrate()`` method, but keep in
-mind that the imported router needs to stay around as long as it is referenced
-inside a migration:
+.. versionadded:: 1.8
+
+You can also provide hints that will be passed to the :meth:`allow_migrate()`
+method of database routers as ``**hints``:
 
 .. snippet::
     :filename: myapp/dbrouters.py
 
     class MyRouter(object):
 
-        def allow_migrate(self, db, model):
-            return db == 'default'
+        def allow_migrate(self, db, model, **hints):
+            if 'target_db' in hints:
+                return db == hints['target_db']
+            return True
 
 Then, to leverage this in your migrations, do the following::
 
     from django.db import migrations
 
-    from myappname.dbrouters import MyRouter
-
     def forwards(apps, schema_editor):
-        MyModel = apps.get_model("myappname", "MyModel")
-        if not MyRouter().allow_migrate(schema_editor.connection.alias, MyModel):
-            return
         # Your migration code goes here
 
     class Migration(migrations.Migration):
@@ -576,7 +574,7 @@ Then, to leverage this in your migrations, do the following::
         ]
 
         operations = [
-            migrations.RunPython(forwards),
+            migrations.RunPython(forwards, hints={'target_db': 'default'}),
         ]
 
 More advanced migrations

+ 174 - 0
tests/migrations/test_multidb.py

@@ -0,0 +1,174 @@
+import unittest
+
+try:
+    import sqlparse
+except ImportError:
+    sqlparse = None
+
+from django.db import migrations, models, connection
+from django.db.migrations.state import ProjectState
+from django.test import override_settings
+
+from .test_operations import OperationTestBase
+
+
+class AgnosticRouter(object):
+    """
+    A router that doesn't have an opinion regarding migrating.
+    """
+    def allow_migrate(self, db, model, **hints):
+        return None
+
+
+class MigrateNothingRouter(object):
+    """
+    A router that doesn't allow migrating.
+    """
+    def allow_migrate(self, db, model, **hints):
+        return False
+
+
+class MigrateEverythingRouter(object):
+    """
+    A router that always allows migrating.
+    """
+    def allow_migrate(self, db, model, **hints):
+        return True
+
+
+class MigrateWhenFooRouter(object):
+    """
+    A router that allows migrating depending on a hint.
+    """
+    def allow_migrate(self, db, model, **hints):
+        return hints.get('foo', False)
+
+
+class MultiDBOperationTests(OperationTestBase):
+    multi_db = True
+
+    def _test_create_model(self, app_label, should_run):
+        """
+        Tests that CreateModel honours multi-db settings.
+        """
+        operation = migrations.CreateModel(
+            "Pony",
+            [("id", models.AutoField(primary_key=True))],
+        )
+        # Test the state alteration
+        project_state = ProjectState()
+        new_state = project_state.clone()
+        operation.state_forwards(app_label, new_state)
+        # Test the database alteration
+        self.assertTableNotExists("%s_pony" % app_label)
+        with connection.schema_editor() as editor:
+            operation.database_forwards(app_label, editor, project_state, new_state)
+        if should_run:
+            self.assertTableExists("%s_pony" % app_label)
+        else:
+            self.assertTableNotExists("%s_pony" % app_label)
+        # And test reversal
+        with connection.schema_editor() as editor:
+            operation.database_backwards(app_label, editor, new_state, project_state)
+        self.assertTableNotExists("%s_pony" % app_label)
+
+    @override_settings(DATABASE_ROUTERS=[AgnosticRouter()])
+    def test_create_model(self):
+        """
+        Test when router doesn't have an opinion (i.e. CreateModel should run).
+        """
+        self._test_create_model("test_mltdb_crmo", should_run=True)
+
+    @override_settings(DATABASE_ROUTERS=[MigrateNothingRouter()])
+    def test_create_model2(self):
+        """
+        Test when router returns False (i.e. CreateModel shouldn't run).
+        """
+        self._test_create_model("test_mltdb_crmo2", should_run=False)
+
+    @override_settings(DATABASE_ROUTERS=[MigrateEverythingRouter()])
+    def test_create_model3(self):
+        """
+        Test when router returns True (i.e. CreateModel should run).
+        """
+        self._test_create_model("test_mltdb_crmo3", should_run=True)
+
+    def test_create_model4(self):
+        """
+        Test multiple routers.
+        """
+        with override_settings(DATABASE_ROUTERS=[AgnosticRouter(), AgnosticRouter()]):
+            self._test_create_model("test_mltdb_crmo4", should_run=True)
+        with override_settings(DATABASE_ROUTERS=[MigrateNothingRouter(), MigrateEverythingRouter()]):
+            self._test_create_model("test_mltdb_crmo4", should_run=False)
+        with override_settings(DATABASE_ROUTERS=[MigrateEverythingRouter(), MigrateNothingRouter()]):
+            self._test_create_model("test_mltdb_crmo4", should_run=True)
+
+    def _test_run_sql(self, app_label, should_run, hints=None):
+        with override_settings(DATABASE_ROUTERS=[MigrateEverythingRouter()]):
+            project_state = self.set_up_test_model(app_label)
+
+        sql = """
+        INSERT INTO {0}_pony (pink, weight) VALUES (1, 3.55);
+        INSERT INTO {0}_pony (pink, weight) VALUES (3, 5.0);
+        """.format(app_label)
+
+        operation = migrations.RunSQL(sql, hints=hints or {})
+        # Test the state alteration does nothing
+        new_state = project_state.clone()
+        operation.state_forwards(app_label, new_state)
+        self.assertEqual(new_state, project_state)
+        # Test the database alteration
+        self.assertEqual(project_state.apps.get_model(app_label, "Pony").objects.count(), 0)
+        with connection.schema_editor() as editor:
+            operation.database_forwards(app_label, editor, project_state, new_state)
+        Pony = project_state.apps.get_model(app_label, "Pony")
+        if should_run:
+            self.assertEqual(Pony.objects.count(), 2)
+        else:
+            self.assertEqual(Pony.objects.count(), 0)
+
+    @unittest.skipIf(sqlparse is None and connection.features.requires_sqlparse_for_splitting, "Missing sqlparse")
+    @override_settings(DATABASE_ROUTERS=[MigrateNothingRouter()])
+    def test_run_sql(self):
+        self._test_run_sql("test_mltdb_runsql", should_run=False)
+
+    @unittest.skipIf(sqlparse is None and connection.features.requires_sqlparse_for_splitting, "Missing sqlparse")
+    @override_settings(DATABASE_ROUTERS=[MigrateWhenFooRouter()])
+    def test_run_sql2(self):
+        self._test_run_sql("test_mltdb_runsql2", should_run=False)
+        self._test_run_sql("test_mltdb_runsql2", should_run=True, hints={'foo': True})
+
+    def _test_run_python(self, app_label, should_run, hints=None):
+        with override_settings(DATABASE_ROUTERS=[MigrateEverythingRouter()]):
+            project_state = self.set_up_test_model(app_label)
+
+        # Create the operation
+        def inner_method(models, schema_editor):
+            Pony = models.get_model(app_label, "Pony")
+            Pony.objects.create(pink=1, weight=3.55)
+            Pony.objects.create(weight=5)
+
+        operation = migrations.RunPython(inner_method, hints=hints or {})
+        # Test the state alteration does nothing
+        new_state = project_state.clone()
+        operation.state_forwards(app_label, new_state)
+        self.assertEqual(new_state, project_state)
+        # Test the database alteration
+        self.assertEqual(project_state.apps.get_model(app_label, "Pony").objects.count(), 0)
+        with connection.schema_editor() as editor:
+            operation.database_forwards(app_label, editor, project_state, new_state)
+        Pony = project_state.apps.get_model(app_label, "Pony")
+        if should_run:
+            self.assertEqual(Pony.objects.count(), 2)
+        else:
+            self.assertEqual(Pony.objects.count(), 0)
+
+    @override_settings(DATABASE_ROUTERS=[MigrateNothingRouter()])
+    def test_run_python(self):
+        self._test_run_python("test_mltdb_runpython", should_run=False)
+
+    @override_settings(DATABASE_ROUTERS=[MigrateWhenFooRouter()])
+    def test_run_python2(self):
+        self._test_run_python("test_mltdb_runpython2", should_run=False)
+        self._test_run_python("test_mltdb_runpython2", should_run=True, hints={'foo': True})

+ 0 - 38
tests/migrations/test_operations.py

@@ -1679,44 +1679,6 @@ class OperationTests(OperationTestBase):
         self.assertEqual(sorted(definition[2]), ["database_operations", "state_operations"])
 
 
-class MigrateNothingRouter(object):
-    """
-    A router that doesn't allow storing any model in any database.
-    """
-    def allow_migrate(self, db, model):
-        return False
-
-
-@override_settings(DATABASE_ROUTERS=[MigrateNothingRouter()])
-class MultiDBOperationTests(MigrationTestBase):
-    multi_db = True
-
-    def test_create_model(self):
-        """
-        Tests that CreateModel honours multi-db settings.
-        """
-        operation = migrations.CreateModel(
-            "Pony",
-            [
-                ("id", models.AutoField(primary_key=True)),
-                ("pink", models.IntegerField(default=1)),
-            ],
-        )
-        # Test the state alteration
-        project_state = ProjectState()
-        new_state = project_state.clone()
-        operation.state_forwards("test_crmo", new_state)
-        # Test the database alteration
-        self.assertTableNotExists("test_crmo_pony")
-        with connection.schema_editor() as editor:
-            operation.database_forwards("test_crmo", editor, project_state, new_state)
-        self.assertTableNotExists("test_crmo_pony")
-        # And test reversal
-        with connection.schema_editor() as editor:
-            operation.database_backwards("test_crmo", editor, new_state, project_state)
-        self.assertTableNotExists("test_crmo_pony")
-
-
 class SwappableOperationTests(OperationTestBase):
     """
     Tests that key operations ignore swappable models