Browse Source

Fixed #26291 -- Allowed loaddata to handle forward references in natural_key fixtures.

Peter Inglesby 6 years ago
parent
commit
312eb5cb11

+ 6 - 0
django/core/management/commands/loaddata.py

@@ -109,8 +109,11 @@ class Command(BaseCommand):
             return
 
         with connection.constraint_checks_disabled():
+            self.objs_with_deferred_fields = []
             for fixture_label in fixture_labels:
                 self.load_label(fixture_label)
+            for obj in self.objs_with_deferred_fields:
+                obj.save_deferred_fields(using=self.using)
 
         # Since we disabled constraint checks, we must manually check for
         # any invalid keys that might have been added
@@ -163,6 +166,7 @@ class Command(BaseCommand):
 
                 objects = serializers.deserialize(
                     ser_fmt, fixture, using=self.using, ignorenonexistent=self.ignore,
+                    handle_forward_references=True,
                 )
 
                 for obj in objects:
@@ -189,6 +193,8 @@ class Command(BaseCommand):
                                 'error_msg': e,
                             },)
                             raise
+                    if obj.deferred_fields:
+                        self.objs_with_deferred_fields.append(obj)
                 if objects and show_progress:
                     self.stdout.write('')  # add a newline after progress indicator
                 self.loaded_object_count += loaded_objects_in_fixture

+ 37 - 5
django/core/serializers/base.py

@@ -3,8 +3,11 @@ Module for abstract serializer/unserializer base classes.
 """
 from io import StringIO
 
+from django.core.exceptions import ObjectDoesNotExist
 from django.db import models
 
+DEFER_FIELD = object()
+
 
 class SerializerDoesNotExist(KeyError):
     """The requested serializer was not found."""
@@ -201,9 +204,10 @@ class DeserializedObject:
     (and not touch the many-to-many stuff.)
     """
 
-    def __init__(self, obj, m2m_data=None):
+    def __init__(self, obj, m2m_data=None, deferred_fields=None):
         self.object = obj
         self.m2m_data = m2m_data
+        self.deferred_fields = deferred_fields
 
     def __repr__(self):
         return "<%s: %s(pk=%s)>" % (
@@ -225,6 +229,25 @@ class DeserializedObject:
         # the m2m data twice.
         self.m2m_data = None
 
+    def save_deferred_fields(self, using=None):
+        self.m2m_data = {}
+        for field, field_value in self.deferred_fields.items():
+            opts = self.object._meta
+            label = opts.app_label + '.' + opts.model_name
+            if isinstance(field.remote_field, models.ManyToManyRel):
+                try:
+                    values = deserialize_m2m_values(field, field_value, using, handle_forward_references=False)
+                except M2MDeserializationError as e:
+                    raise DeserializationError.WithData(e.original_exc, label, self.object.pk, e.pk)
+                self.m2m_data[field.name] = values
+            elif isinstance(field.remote_field, models.ManyToOneRel):
+                try:
+                    value = deserialize_fk_value(field, field_value, using, handle_forward_references=False)
+                except Exception as e:
+                    raise DeserializationError.WithData(e, label, self.object.pk, field_value)
+                setattr(self.object, field.attname, value)
+        self.save()
+
 
 def build_instance(Model, data, db):
     """
@@ -244,7 +267,7 @@ def build_instance(Model, data, db):
     return obj
 
 
-def deserialize_m2m_values(field, field_value, using):
+def deserialize_m2m_values(field, field_value, using, handle_forward_references):
     model = field.remote_field.model
     if hasattr(model._default_manager, 'get_by_natural_key'):
         def m2m_convert(value):
@@ -262,10 +285,13 @@ def deserialize_m2m_values(field, field_value, using):
             values.append(m2m_convert(pk))
         return values
     except Exception as e:
-        raise M2MDeserializationError(e, pk)
+        if isinstance(e, ObjectDoesNotExist) and handle_forward_references:
+            return DEFER_FIELD
+        else:
+            raise M2MDeserializationError(e, pk)
 
 
-def deserialize_fk_value(field, field_value, using):
+def deserialize_fk_value(field, field_value, using, handle_forward_references):
     if field_value is None:
         return None
     model = field.remote_field.model
@@ -273,7 +299,13 @@ def deserialize_fk_value(field, field_value, using):
     field_name = field.remote_field.field_name
     if (hasattr(default_manager, 'get_by_natural_key') and
             hasattr(field_value, '__iter__') and not isinstance(field_value, str)):
-        obj = default_manager.db_manager(using).get_by_natural_key(*field_value)
+        try:
+            obj = default_manager.db_manager(using).get_by_natural_key(*field_value)
+        except ObjectDoesNotExist:
+            if handle_forward_references:
+                return DEFER_FIELD
+            else:
+                raise
         value = getattr(obj, field_name)
         # If this is a natural foreign key to an object that has a FK/O2O as
         # the foreign key, use the FK value.

+ 13 - 5
django/core/serializers/python.py

@@ -83,6 +83,7 @@ def Deserializer(object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False
     It's expected that you pass the Python objects themselves (instead of a
     stream or a string) to the constructor
     """
+    handle_forward_references = options.pop('handle_forward_references', False)
     field_names_cache = {}  # Model: <list of field_names>
 
     for d in object_list:
@@ -101,6 +102,7 @@ def Deserializer(object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False
             except Exception as e:
                 raise base.DeserializationError.WithData(e, d['model'], d.get('pk'), None)
         m2m_data = {}
+        deferred_fields = {}
 
         if Model not in field_names_cache:
             field_names_cache[Model] = {f.name for f in Model._meta.get_fields()}
@@ -118,17 +120,23 @@ def Deserializer(object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False
             # Handle M2M relations
             if field.remote_field and isinstance(field.remote_field, models.ManyToManyRel):
                 try:
-                    values = base.deserialize_m2m_values(field, field_value, using)
+                    values = base.deserialize_m2m_values(field, field_value, using, handle_forward_references)
                 except base.M2MDeserializationError as e:
                     raise base.DeserializationError.WithData(e.original_exc, d['model'], d.get('pk'), e.pk)
-                m2m_data[field.name] = values
+                if values == base.DEFER_FIELD:
+                    deferred_fields[field] = field_value
+                else:
+                    m2m_data[field.name] = values
             # Handle FK fields
             elif field.remote_field and isinstance(field.remote_field, models.ManyToOneRel):
                 try:
-                    value = base.deserialize_fk_value(field, field_value, using)
+                    value = base.deserialize_fk_value(field, field_value, using, handle_forward_references)
                 except Exception as e:
                     raise base.DeserializationError.WithData(e, d['model'], d.get('pk'), field_value)
-                data[field.attname] = value
+                if value == base.DEFER_FIELD:
+                    deferred_fields[field] = field_value
+                else:
+                    data[field.attname] = value
             # Handle all other fields
             else:
                 try:
@@ -137,7 +145,7 @@ def Deserializer(object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False
                     raise base.DeserializationError.WithData(e, d['model'], d.get('pk'), field_value)
 
         obj = base.build_instance(Model, data, using)
-        yield base.DeserializedObject(obj, m2m_data)
+        yield base.DeserializedObject(obj, m2m_data, deferred_fields)
 
 
 def _get_model(model_identifier):

+ 41 - 5
django/core/serializers/xml_serializer.py

@@ -8,6 +8,7 @@ from xml.sax.expatreader import ExpatParser as _ExpatParser
 
 from django.apps import apps
 from django.conf import settings
+from django.core.exceptions import ObjectDoesNotExist
 from django.core.serializers import base
 from django.db import DEFAULT_DB_ALIAS, models
 from django.utils.xmlutils import (
@@ -151,6 +152,7 @@ class Deserializer(base.Deserializer):
 
     def __init__(self, stream_or_string, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False, **options):
         super().__init__(stream_or_string, **options)
+        self.handle_forward_references = options.pop('handle_forward_references', False)
         self.event_stream = pulldom.parse(self.stream, self._make_parser())
         self.db = using
         self.ignore = ignorenonexistent
@@ -181,6 +183,7 @@ class Deserializer(base.Deserializer):
         # Also start building a dict of m2m data (this is saved as
         # {m2m_accessor_attribute : [list_of_related_objects]})
         m2m_data = {}
+        deferred_fields = {}
 
         field_names = {f.name for f in Model._meta.get_fields()}
         # Deserialize each field.
@@ -200,9 +203,26 @@ class Deserializer(base.Deserializer):
 
             # As is usually the case, relation fields get the special treatment.
             if field.remote_field and isinstance(field.remote_field, models.ManyToManyRel):
-                m2m_data[field.name] = self._handle_m2m_field_node(field_node, field)
+                value = self._handle_m2m_field_node(field_node, field)
+                if value == base.DEFER_FIELD:
+                    deferred_fields[field] = [
+                        [
+                            getInnerText(nat_node).strip()
+                            for nat_node in obj_node.getElementsByTagName('natural')
+                        ]
+                        for obj_node in field_node.getElementsByTagName('object')
+                    ]
+                else:
+                    m2m_data[field.name] = value
             elif field.remote_field and isinstance(field.remote_field, models.ManyToOneRel):
-                data[field.attname] = self._handle_fk_field_node(field_node, field)
+                value = self._handle_fk_field_node(field_node, field)
+                if value == base.DEFER_FIELD:
+                    deferred_fields[field] = [
+                        getInnerText(k).strip()
+                        for k in field_node.getElementsByTagName('natural')
+                    ]
+                else:
+                    data[field.attname] = value
             else:
                 if field_node.getElementsByTagName('None'):
                     value = None
@@ -213,7 +233,7 @@ class Deserializer(base.Deserializer):
         obj = base.build_instance(Model, data, self.db)
 
         # Return a DeserializedObject so that the m2m data has a place to live.
-        return base.DeserializedObject(obj, m2m_data)
+        return base.DeserializedObject(obj, m2m_data, deferred_fields)
 
     def _handle_fk_field_node(self, node, field):
         """
@@ -229,7 +249,13 @@ class Deserializer(base.Deserializer):
                 if keys:
                     # If there are 'natural' subelements, it must be a natural key
                     field_value = [getInnerText(k).strip() for k in keys]
-                    obj = model._default_manager.db_manager(self.db).get_by_natural_key(*field_value)
+                    try:
+                        obj = model._default_manager.db_manager(self.db).get_by_natural_key(*field_value)
+                    except ObjectDoesNotExist:
+                        if self.handle_forward_references:
+                            return base.DEFER_FIELD
+                        else:
+                            raise
                     obj_pk = getattr(obj, field.remote_field.field_name)
                     # If this is a natural foreign key to an object that
                     # has a FK/O2O as the foreign key, use the FK value
@@ -264,7 +290,17 @@ class Deserializer(base.Deserializer):
         else:
             def m2m_convert(n):
                 return model._meta.pk.to_python(n.getAttribute('pk'))
-        return [m2m_convert(c) for c in node.getElementsByTagName("object")]
+        values = []
+        try:
+            for c in node.getElementsByTagName('object'):
+                values.append(m2m_convert(c))
+        except Exception as e:
+            if isinstance(e, ObjectDoesNotExist) and self.handle_forward_references:
+                return base.DEFER_FIELD
+            else:
+                raise base.M2MDeserializationError(e, c)
+        else:
+            return values
 
     def _get_model_from_node(self, node, attr):
         """

+ 4 - 1
docs/releases/2.2.txt

@@ -184,7 +184,10 @@ Requests and Responses
 Serialization
 ~~~~~~~~~~~~~
 
-* ...
+* You can now deserialize data using natural keys containing :ref:`forward
+  references <natural-keys-and-forward-references>` by passing
+  ``handle_forward_references=True`` to ``serializers.deserialize()``.
+  Additionally, :djadmin:`loaddata` handles forward references automatically.
 
 Signals
 ~~~~~~~

+ 58 - 7
docs/topics/serialization.txt

@@ -514,17 +514,68 @@ command line flags to generate natural keys.
     natural keys during serialization, but *not* be able to load those
     key values, just don't define the ``get_by_natural_key()`` method.
 
+.. _natural-keys-and-forward-references:
+
+Natural keys and forward references
+-----------------------------------
+
+.. versionadded:: 2.2
+
+Sometimes when you use :ref:`natural foreign keys
+<topics-serialization-natural-keys>` you'll need to deserialize data where
+an object has a foreign key referencing another object that hasn't yet been
+deserialized. This is called a "forward reference".
+
+For instance, suppose you have the following objects in your fixture::
+
+    ...
+    {
+        "model": "store.book",
+        "fields": {
+            "name": "Mostly Harmless",
+            "author": ["Douglas", "Adams"]
+        }
+    },
+    ...
+    {
+        "model": "store.person",
+        "fields": {
+            "first_name": "Douglas",
+            "last_name": "Adams"
+        }
+    },
+    ...
+
+In order to handle this situation, you need to pass
+``handle_forward_references=True`` to ``serializers.deserialize()``. This will
+set the ``deferred_fields`` attribute on the ``DeserializedObject`` instances.
+You'll need to keep track of ``DeserializedObject`` instances where this
+attribute isn't ``None`` and later call ``save_deferred_fields()`` on them.
+
+Typical usage looks like this::
+
+    objs_with_deferred_fields = []
+
+    for obj in serializers.deserialize('xml', data, handle_forward_references=True):
+        obj.save()
+        if obj.deferred_fields is not None:
+            objs_with_deferred_fields.append(obj)
+
+    for obj in objs_with_deferred_fields:
+        obj.save_deferred_fields()
+
+For this to work, the ``ForeignKey`` on the referencing model must have
+``null=True``.
+
 Dependencies during serialization
 ---------------------------------
 
-Since natural keys rely on database lookups to resolve references, it
-is important that the data exists before it is referenced. You can't make
-a "forward reference" with natural keys -- the data you're referencing
-must exist before you include a natural key reference to that data.
+It's often possible to avoid explicitly having to handle forward references by
+taking care with the ordering of objects within a fixture.
 
-To accommodate this limitation, calls to :djadmin:`dumpdata` that use
-the :option:`dumpdata --natural-foreign` option will serialize any model with a
-``natural_key()`` method before serializing standard primary key objects.
+To help with this, calls to :djadmin:`dumpdata` that use the :option:`dumpdata
+--natural-foreign` option will serialize any model with a ``natural_key()``
+method before serializing standard primary key objects.
 
 However, this may not always be enough. If your natural key refers to
 another object (by using a foreign key or natural key to another object

+ 20 - 0
tests/fixtures/fixtures/forward_reference_fk.json

@@ -0,0 +1,20 @@
+[
+  {
+    "model": "fixtures.naturalkeything",
+    "fields": {
+      "key": "t1",
+      "other_thing": [
+        "t2"
+      ]
+    }
+  },
+  {
+    "model": "fixtures.naturalkeything",
+    "fields": {
+      "key": "t2",
+      "other_thing": [
+        "t1"
+      ]
+    }
+  }
+]

+ 23 - 0
tests/fixtures/fixtures/forward_reference_m2m.json

@@ -0,0 +1,23 @@
+[
+  {
+    "model": "fixtures.naturalkeything",
+    "fields": {
+      "key": "t1",
+      "other_things": [
+        ["t2"], ["t3"]
+      ]
+    }
+  },
+  {
+    "model": "fixtures.naturalkeything",
+    "fields": {
+      "key": "t2"
+    }
+  },
+  {
+    "model": "fixtures.naturalkeything",
+    "fields": {
+      "key": "t3"
+    }
+  }
+]

+ 18 - 0
tests/fixtures/models.py

@@ -116,3 +116,21 @@ class Book(models.Model):
 
 class PrimaryKeyUUIDModel(models.Model):
     id = models.UUIDField(primary_key=True, default=uuid.uuid4)
+
+
+class NaturalKeyThing(models.Model):
+    key = models.CharField(max_length=100)
+    other_thing = models.ForeignKey('NaturalKeyThing', on_delete=models.CASCADE, null=True)
+    other_things = models.ManyToManyField('NaturalKeyThing', related_name='thing_m2m_set')
+
+    class Manager(models.Manager):
+        def get_by_natural_key(self, key):
+            return self.get(key=key)
+
+    objects = Manager()
+
+    def natural_key(self):
+        return (self.key,)
+
+    def __str__(self):
+        return self.key

+ 20 - 1
tests/fixtures/tests.py

@@ -17,7 +17,8 @@ from django.db import IntegrityError, connection
 from django.test import TestCase, TransactionTestCase, skipUnlessDBFeature
 
 from .models import (
-    Article, Category, PrimaryKeyUUIDModel, ProxySpy, Spy, Tag, Visa,
+    Article, Category, NaturalKeyThing, PrimaryKeyUUIDModel, ProxySpy, Spy,
+    Tag, Visa,
 )
 
 
@@ -780,3 +781,21 @@ class FixtureTransactionTests(DumpDataAssertMixin, TransactionTestCase):
             '<Article: Time to reform copyright>',
             '<Article: Poker has no place on ESPN>',
         ])
+
+
+class ForwardReferenceTests(TestCase):
+    def test_forward_reference_fk(self):
+        management.call_command('loaddata', 'forward_reference_fk.json', verbosity=0)
+        self.assertEqual(NaturalKeyThing.objects.count(), 2)
+        t1, t2 = NaturalKeyThing.objects.all()
+        self.assertEqual(t1.other_thing, t2)
+        self.assertEqual(t2.other_thing, t1)
+
+    def test_forward_reference_m2m(self):
+        management.call_command('loaddata', 'forward_reference_m2m.json', verbosity=0)
+        self.assertEqual(NaturalKeyThing.objects.count(), 3)
+        t1 = NaturalKeyThing.objects.get_by_natural_key('t1')
+        self.assertQuerysetEqual(
+            t1.other_things.order_by('key'),
+            ['<NaturalKeyThing: t2>', '<NaturalKeyThing: t3>']
+        )

+ 18 - 0
tests/serializers/models/natural.py

@@ -19,3 +19,21 @@ class NaturalKeyAnchor(models.Model):
 
 class FKDataNaturalKey(models.Model):
     data = models.ForeignKey(NaturalKeyAnchor, models.SET_NULL, null=True)
+
+
+class NaturalKeyThing(models.Model):
+    key = models.CharField(max_length=100)
+    other_thing = models.ForeignKey('NaturalKeyThing', on_delete=models.CASCADE, null=True)
+    other_things = models.ManyToManyField('NaturalKeyThing', related_name='thing_m2m_set')
+
+    class Manager(models.Manager):
+        def get_by_natural_key(self, key):
+            return self.get(key=key)
+
+    objects = Manager()
+
+    def natural_key(self):
+        return (self.key,)
+
+    def __str__(self):
+        return self.key

+ 94 - 1
tests/serializers/test_natural.py

@@ -2,7 +2,7 @@ from django.core import serializers
 from django.db import connection
 from django.test import TestCase
 
-from .models import Child, FKDataNaturalKey, NaturalKeyAnchor
+from .models import Child, FKDataNaturalKey, NaturalKeyAnchor, NaturalKeyThing
 from .tests import register_tests
 
 
@@ -93,7 +93,100 @@ def natural_pk_mti_test(self, format):
         self.assertEqual(child.child_data, child.parent_data)
 
 
+def forward_ref_fk_test(self, format):
+    t1 = NaturalKeyThing.objects.create(key='t1')
+    t2 = NaturalKeyThing.objects.create(key='t2', other_thing=t1)
+    t1.other_thing = t2
+    t1.save()
+    string_data = serializers.serialize(
+        format, [t1, t2], use_natural_primary_keys=True,
+        use_natural_foreign_keys=True,
+    )
+    NaturalKeyThing.objects.all().delete()
+    objs_with_deferred_fields = []
+    for obj in serializers.deserialize(format, string_data, handle_forward_references=True):
+        obj.save()
+        if obj.deferred_fields:
+            objs_with_deferred_fields.append(obj)
+    for obj in objs_with_deferred_fields:
+        obj.save_deferred_fields()
+    t1 = NaturalKeyThing.objects.get(key='t1')
+    t2 = NaturalKeyThing.objects.get(key='t2')
+    self.assertEqual(t1.other_thing, t2)
+    self.assertEqual(t2.other_thing, t1)
+
+
+def forward_ref_fk_with_error_test(self, format):
+    t1 = NaturalKeyThing.objects.create(key='t1')
+    t2 = NaturalKeyThing.objects.create(key='t2', other_thing=t1)
+    t1.other_thing = t2
+    t1.save()
+    string_data = serializers.serialize(
+        format, [t1], use_natural_primary_keys=True,
+        use_natural_foreign_keys=True,
+    )
+    NaturalKeyThing.objects.all().delete()
+    objs_with_deferred_fields = []
+    for obj in serializers.deserialize(format, string_data, handle_forward_references=True):
+        obj.save()
+        if obj.deferred_fields:
+            objs_with_deferred_fields.append(obj)
+    obj = objs_with_deferred_fields[0]
+    msg = 'NaturalKeyThing matching query does not exist'
+    with self.assertRaisesMessage(serializers.base.DeserializationError, msg):
+        obj.save_deferred_fields()
+
+
+def forward_ref_m2m_test(self, format):
+    t1 = NaturalKeyThing.objects.create(key='t1')
+    t2 = NaturalKeyThing.objects.create(key='t2')
+    t3 = NaturalKeyThing.objects.create(key='t3')
+    t1.other_things.set([t2, t3])
+    string_data = serializers.serialize(
+        format, [t1, t2, t3], use_natural_primary_keys=True,
+        use_natural_foreign_keys=True,
+    )
+    NaturalKeyThing.objects.all().delete()
+    objs_with_deferred_fields = []
+    for obj in serializers.deserialize(format, string_data, handle_forward_references=True):
+        obj.save()
+        if obj.deferred_fields:
+            objs_with_deferred_fields.append(obj)
+    for obj in objs_with_deferred_fields:
+        obj.save_deferred_fields()
+    t1 = NaturalKeyThing.objects.get(key='t1')
+    t2 = NaturalKeyThing.objects.get(key='t2')
+    t3 = NaturalKeyThing.objects.get(key='t3')
+    self.assertCountEqual(t1.other_things.all(), [t2, t3])
+
+
+def forward_ref_m2m_with_error_test(self, format):
+    t1 = NaturalKeyThing.objects.create(key='t1')
+    t2 = NaturalKeyThing.objects.create(key='t2')
+    t3 = NaturalKeyThing.objects.create(key='t3')
+    t1.other_things.set([t2, t3])
+    t1.save()
+    string_data = serializers.serialize(
+        format, [t1, t2], use_natural_primary_keys=True,
+        use_natural_foreign_keys=True,
+    )
+    NaturalKeyThing.objects.all().delete()
+    objs_with_deferred_fields = []
+    for obj in serializers.deserialize(format, string_data, handle_forward_references=True):
+        obj.save()
+        if obj.deferred_fields:
+            objs_with_deferred_fields.append(obj)
+    obj = objs_with_deferred_fields[0]
+    msg = 'NaturalKeyThing matching query does not exist'
+    with self.assertRaisesMessage(serializers.base.DeserializationError, msg):
+        obj.save_deferred_fields()
+
+
 # Dynamically register tests for each serializer
 register_tests(NaturalKeySerializerTests, 'test_%s_natural_key_serializer', natural_key_serializer_test)
 register_tests(NaturalKeySerializerTests, 'test_%s_serializer_natural_keys', natural_key_test)
 register_tests(NaturalKeySerializerTests, 'test_%s_serializer_natural_pks_mti', natural_pk_mti_test)
+register_tests(NaturalKeySerializerTests, 'test_%s_forward_references_fks', forward_ref_fk_test)
+register_tests(NaturalKeySerializerTests, 'test_%s_forward_references_fk_errors', forward_ref_fk_with_error_test)
+register_tests(NaturalKeySerializerTests, 'test_%s_forward_references_m2ms', forward_ref_m2m_test)
+register_tests(NaturalKeySerializerTests, 'test_%s_forward_references_m2m_errors', forward_ref_m2m_with_error_test)