Browse Source

Fixed #21391 -- Allow model signals to lazily reference their senders.

Simon Charette 11 years ago
parent
commit
eb38257e51

+ 29 - 1
django/core/management/validation.py

@@ -1,5 +1,6 @@
 import collections
 import collections
 import sys
 import sys
+import types
 
 
 from django.conf import settings
 from django.conf import settings
 from django.core.management.color import color_style
 from django.core.management.color import color_style
@@ -25,7 +26,7 @@ def get_validation_errors(outfile, app=None):
     validates all models of all installed apps. Writes errors, if any, to outfile.
     validates all models of all installed apps. Writes errors, if any, to outfile.
     Returns number of errors.
     Returns number of errors.
     """
     """
-    from django.db import models, connection
+    from django.db import connection, models
     from django.db.models.loading import get_app_errors
     from django.db.models.loading import get_app_errors
     from django.db.models.deletion import SET_NULL, SET_DEFAULT
     from django.db.models.deletion import SET_NULL, SET_DEFAULT
 
 
@@ -363,6 +364,8 @@ def get_validation_errors(outfile, app=None):
             for it in opts.index_together:
             for it in opts.index_together:
                 validate_local_fields(e, opts, "index_together", it)
                 validate_local_fields(e, opts, "index_together", it)
 
 
+    validate_model_signals(e)
+
     return len(e.errors)
     return len(e.errors)
 
 
 
 
@@ -382,3 +385,28 @@ def validate_local_fields(e, opts, field_name, fields):
                     e.add(opts, '"%s" refers to %s. ManyToManyFields are not supported in %s.' % (field_name, f.name, field_name))
                     e.add(opts, '"%s" refers to %s. ManyToManyFields are not supported in %s.' % (field_name, f.name, field_name))
                 if f not in opts.local_fields:
                 if f not in opts.local_fields:
                     e.add(opts, '"%s" refers to %s. This is not in the same model as the %s statement.' % (field_name, f.name, field_name))
                     e.add(opts, '"%s" refers to %s. This is not in the same model as the %s statement.' % (field_name, f.name, field_name))
+
+
+def validate_model_signals(e):
+    """Ensure lazily referenced model signals senders are installed."""
+    from django.db import models
+
+    for name in dir(models.signals):
+        obj = getattr(models.signals, name)
+        if isinstance(obj, models.signals.ModelSignal):
+            for reference, receivers in obj.unresolved_references.items():
+                for receiver, _, _ in receivers:
+                    # The receiver is either a function or an instance of class
+                    # defining a `__call__` method.
+                    if isinstance(receiver, types.FunctionType):
+                        description = "The `%s` function" % receiver.__name__
+                    else:
+                        description = "An instance of the `%s` class" % receiver.__class__.__name__
+                    e.add(
+                        receiver.__module__,
+                        "%s was connected to the `%s` signal "
+                        "with a lazy reference to the '%s' sender, "
+                        "which has not been installed." % (
+                            description, name, '.'.join(reference)
+                        )
+                    )

+ 59 - 9
django/db/models/signals.py

@@ -1,20 +1,70 @@
+from collections import defaultdict
+
+from django.db.models.loading import get_model
 from django.dispatch import Signal
 from django.dispatch import Signal
+from django.utils import six
+
 
 
 class_prepared = Signal(providing_args=["class"])
 class_prepared = Signal(providing_args=["class"])
 
 
-pre_init = Signal(providing_args=["instance", "args", "kwargs"], use_caching=True)
-post_init = Signal(providing_args=["instance"], use_caching=True)
 
 
-pre_save = Signal(providing_args=["instance", "raw", "using", "update_fields"],
-                 use_caching=True)
-post_save = Signal(providing_args=["instance", "raw", "created", "using", "update_fields"], use_caching=True)
+class ModelSignal(Signal):
+    """
+    Signal subclass that allows the sender to be lazily specified as a string
+    of the `app_label.ModelName` form.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super(ModelSignal, self).__init__(*args, **kwargs)
+        self.unresolved_references = defaultdict(list)
+        class_prepared.connect(self._resolve_references)
+
+    def _resolve_references(self, sender, **kwargs):
+        opts = sender._meta
+        reference = (opts.app_label, opts.object_name)
+        try:
+            receivers = self.unresolved_references.pop(reference)
+        except KeyError:
+            pass
+        else:
+            for receiver, weak, dispatch_uid in receivers:
+                super(ModelSignal, self).connect(
+                    receiver, sender=sender, weak=weak, dispatch_uid=dispatch_uid
+                )
+
+    def connect(self, receiver, sender=None, weak=True, dispatch_uid=None):
+        if isinstance(sender, six.string_types):
+            try:
+                app_label, object_name = sender.split('.')
+            except ValueError:
+                raise ValueError(
+                    "Specified sender must either be a model or a "
+                    "model name of the 'app_label.ModelName' form."
+                )
+            sender = get_model(app_label, object_name, only_installed=False)
+            if sender is None:
+                reference = (app_label, object_name)
+                self.unresolved_references[reference].append(
+                    (receiver, weak, dispatch_uid)
+                )
+                return
+        super(ModelSignal, self).connect(
+            receiver, sender=sender, weak=weak, dispatch_uid=dispatch_uid
+        )
 
 
-pre_delete = Signal(providing_args=["instance", "using"], use_caching=True)
-post_delete = Signal(providing_args=["instance", "using"], use_caching=True)
+pre_init = ModelSignal(providing_args=["instance", "args", "kwargs"], use_caching=True)
+post_init = ModelSignal(providing_args=["instance"], use_caching=True)
+
+pre_save = ModelSignal(providing_args=["instance", "raw", "using", "update_fields"],
+                       use_caching=True)
+post_save = ModelSignal(providing_args=["instance", "raw", "created", "using", "update_fields"], use_caching=True)
+
+pre_delete = ModelSignal(providing_args=["instance", "using"], use_caching=True)
+post_delete = ModelSignal(providing_args=["instance", "using"], use_caching=True)
+
+m2m_changed = ModelSignal(providing_args=["action", "instance", "reverse", "model", "pk_set", "using"], use_caching=True)
 
 
 pre_migrate = Signal(providing_args=["app", "create_models", "verbosity", "interactive", "db"])
 pre_migrate = Signal(providing_args=["app", "create_models", "verbosity", "interactive", "db"])
 pre_syncdb = pre_migrate
 pre_syncdb = pre_migrate
 post_migrate = Signal(providing_args=["class", "app", "created_models", "verbosity", "interactive", "db"])
 post_migrate = Signal(providing_args=["class", "app", "created_models", "verbosity", "interactive", "db"])
 post_syncdb = post_migrate
 post_syncdb = post_migrate
-
-m2m_changed = Signal(providing_args=["action", "instance", "reverse", "model", "pk_set", "using"], use_caching=True)

+ 9 - 1
docs/ref/signals.txt

@@ -22,7 +22,7 @@ Model signals
    :synopsis: Signals sent by the model system.
    :synopsis: Signals sent by the model system.
 
 
 The :mod:`django.db.models.signals` module defines a set of signals sent by the
 The :mod:`django.db.models.signals` module defines a set of signals sent by the
-module system.
+model system.
 
 
 .. warning::
 .. warning::
 
 
@@ -37,6 +37,14 @@ module system.
     so if your handler is a local function, it may be garbage collected.  To
     so if your handler is a local function, it may be garbage collected.  To
     prevent this, pass ``weak=False`` when you call the signal's :meth:`~django.dispatch.Signal.connect`.
     prevent this, pass ``weak=False`` when you call the signal's :meth:`~django.dispatch.Signal.connect`.
 
 
+.. versionadded:: 1.7
+
+    Model signals ``sender`` model can be lazily referenced when connecting a
+    receiver by specifying its full application label. For example, an
+    ``Answer`` model defined in the ``polls`` application could be referenced
+    as ``'polls.Answer'``. This sort of reference can be quite handy when
+    dealing with circular import dependencies and swappable models.
+
 pre_init
 pre_init
 --------
 --------
 
 

+ 5 - 1
docs/releases/1.7.txt

@@ -425,7 +425,7 @@ Models
 * Is it now possible to avoid creating a backward relation for
 * Is it now possible to avoid creating a backward relation for
   :class:`~django.db.models.OneToOneField` by setting its
   :class:`~django.db.models.OneToOneField` by setting its
   :attr:`~django.db.models.ForeignKey.related_name` to
   :attr:`~django.db.models.ForeignKey.related_name` to
-  `'+'` or ending it with `'+'`.
+  ``'+'`` or ending it with ``'+'``.
 
 
 * :class:`F expressions <django.db.models.F>` support the power operator
 * :class:`F expressions <django.db.models.F>` support the power operator
   (``**``).
   (``**``).
@@ -436,6 +436,10 @@ Signals
 * The ``enter`` argument was added to the
 * The ``enter`` argument was added to the
   :data:`~django.test.signals.setting_changed` signal.
   :data:`~django.test.signals.setting_changed` signal.
 
 
+* The model signals can be now be connected to using a ``str`` of the
+  ``'app_label.ModelName'`` form – just like related fields – to lazily
+  reference their senders.
+
 Templates
 Templates
 ^^^^^^^^^
 ^^^^^^^^^
 
 

+ 13 - 0
docs/topics/auth/customizing.txt

@@ -413,6 +413,19 @@ different User model.
         class Article(models.Model):
         class Article(models.Model):
             author = models.ForeignKey(settings.AUTH_USER_MODEL)
             author = models.ForeignKey(settings.AUTH_USER_MODEL)
 
 
+    .. versionadded:: 1.7
+
+        When connecting to signals sent by the User model, you should specify the
+        custom model using the :setting:`AUTH_USER_MODEL` setting. For example::
+
+            from django.conf import settings
+            from django.db.models.signals import post_save
+
+            def post_save_receiver(signal, sender, instance, **kwargs):
+                pass
+
+            post_save.connect(post_save_receiver, sender=settings.AUTH_USER_MODEL)
+
 Specifying a custom User model
 Specifying a custom User model
 ------------------------------
 ------------------------------
 
 

+ 33 - 1
tests/model_validation/tests.py

@@ -1,10 +1,22 @@
 from django.core import management
 from django.core import management
+from django.core.management.validation import (
+    ModelErrorCollection, validate_model_signals
+)
+from django.db.models.signals import post_init
 from django.test import TestCase
 from django.test import TestCase
 from django.utils import six
 from django.utils import six
 
 
 
 
-class ModelValidationTest(TestCase):
+class OnPostInit(object):
+    def __call__(self, **kwargs):
+        pass
+
+
+def on_post_init(**kwargs):
+    pass
 
 
+
+class ModelValidationTest(TestCase):
     def test_models_validate(self):
     def test_models_validate(self):
         # All our models should validate properly
         # All our models should validate properly
         # Validation Tests:
         # Validation Tests:
@@ -13,3 +25,23 @@ class ModelValidationTest(TestCase):
         #   * related_name='+' doesn't clash with another '+'
         #   * related_name='+' doesn't clash with another '+'
         #       See: https://code.djangoproject.com/ticket/21375
         #       See: https://code.djangoproject.com/ticket/21375
         management.call_command("validate", stdout=six.StringIO())
         management.call_command("validate", stdout=six.StringIO())
+
+    def test_model_signal(self):
+        unresolved_references = post_init.unresolved_references.copy()
+        post_init.connect(on_post_init, sender='missing-app.Model')
+        post_init.connect(OnPostInit(), sender='missing-app.Model')
+        e = ModelErrorCollection(six.StringIO())
+        validate_model_signals(e)
+        self.assertSetEqual(set(e.errors), {
+            ('model_validation.tests',
+                "The `on_post_init` function was connected to the `post_init` "
+                "signal with a lazy reference to the 'missing-app.Model' "
+                "sender, which has not been installed."
+            ),
+            ('model_validation.tests',
+                "An instance of the `OnPostInit` class was connected to "
+                "the `post_init` signal with a lazy reference to the "
+                "'missing-app.Model' sender, which has not been installed."
+            )
+        })
+        post_init.unresolved_references = unresolved_references

+ 49 - 2
tests/signals/tests.py

@@ -1,5 +1,6 @@
 from __future__ import unicode_literals
 from __future__ import unicode_literals
 
 
+from django.db import models
 from django.db.models import signals
 from django.db.models import signals
 from django.dispatch import receiver
 from django.dispatch import receiver
 from django.test import TestCase
 from django.test import TestCase
@@ -8,8 +9,7 @@ from django.utils import six
 from .models import Author, Book, Car, Person
 from .models import Author, Book, Car, Person
 
 
 
 
-class SignalTests(TestCase):
-
+class BaseSignalTest(TestCase):
     def setUp(self):
     def setUp(self):
         # Save up the number of connected signals so that we can check at the
         # Save up the number of connected signals so that we can check at the
         # end that all the signals we register get properly unregistered (#9989)
         # end that all the signals we register get properly unregistered (#9989)
@@ -30,6 +30,8 @@ class SignalTests(TestCase):
         )
         )
         self.assertEqual(self.pre_signals, post_signals)
         self.assertEqual(self.pre_signals, post_signals)
 
 
+
+class SignalTests(BaseSignalTest):
     def test_save_signals(self):
     def test_save_signals(self):
         data = []
         data = []
 
 
@@ -239,3 +241,48 @@ class SignalTests(TestCase):
         self.assertTrue(a._run)
         self.assertTrue(a._run)
         self.assertTrue(b._run)
         self.assertTrue(b._run)
         self.assertEqual(signals.post_save.receivers, [])
         self.assertEqual(signals.post_save.receivers, [])
+
+
+class LazyModelRefTest(BaseSignalTest):
+    def setUp(self):
+        super(LazyModelRefTest, self).setUp()
+        self.received = []
+
+    def receiver(self, **kwargs):
+        self.received.append(kwargs)
+
+    def test_invalid_sender_model_name(self):
+        with self.assertRaisesMessage(ValueError,
+                    "Specified sender must either be a model or a "
+                    "model name of the 'app_label.ModelName' form."):
+            signals.post_init.connect(self.receiver, sender='invalid')
+
+    def test_already_loaded_model(self):
+        signals.post_init.connect(
+            self.receiver, sender='signals.Book', weak=False
+        )
+        try:
+            instance = Book()
+            self.assertEqual(self.received, [{
+                'signal': signals.post_init,
+                'sender': Book,
+                'instance': instance
+            }])
+        finally:
+            signals.post_init.disconnect(self.receiver, sender=Book)
+
+    def test_not_loaded_model(self):
+        signals.post_init.connect(
+            self.receiver, sender='signals.Created', weak=False
+        )
+
+        try:
+            class Created(models.Model):
+                pass
+
+            instance = Created()
+            self.assertEqual(self.received, [{
+                'signal': signals.post_init, 'sender': Created, 'instance': instance
+            }])
+        finally:
+            signals.post_init.disconnect(self.receiver, sender=Created)