Browse Source

Fixed #29522 -- Refactored the Deserializer functions to classes.

Co-authored-by: Emad Mokhtar <emad.mokhtar@veneficus.nl>
Amir Karimi 6 months ago
parent
commit
ee5147cfd7

+ 1 - 1
AUTHORS

@@ -68,7 +68,7 @@ answer newbie questions, and generally made Django that much better:
     Aljaž Košir <aljazkosir5@gmail.com>
     Aljosa Mohorovic <aljosa.mohorovic@gmail.com>
     Alokik Vijay <alokik.roe@gmail.com>
-    Amir Karimi <amk9978@gmail.com>
+    Amir Karimi <https://github.com/amk9978>
     Amit Chakradeo <https://amit.chakradeo.net/>
     Amit Ramon <amit.ramon@gmail.com>
     Amit Upadhyay <http://www.amitu.com/blog/>

+ 20 - 12
django/core/serializers/json.py

@@ -59,19 +59,27 @@ class Serializer(PythonSerializer):
         return super(PythonSerializer, self).getvalue()
 
 
-def Deserializer(stream_or_string, **options):
+class Deserializer(PythonDeserializer):
     """Deserialize a stream or string of JSON data."""
-    if not isinstance(stream_or_string, (bytes, str)):
-        stream_or_string = stream_or_string.read()
-    if isinstance(stream_or_string, bytes):
-        stream_or_string = stream_or_string.decode()
-    try:
-        objects = json.loads(stream_or_string)
-        yield from PythonDeserializer(objects, **options)
-    except (GeneratorExit, DeserializationError):
-        raise
-    except Exception as exc:
-        raise DeserializationError() from exc
+
+    def __init__(self, stream_or_string, **options):
+        if not isinstance(stream_or_string, (bytes, str)):
+            stream_or_string = stream_or_string.read()
+        if isinstance(stream_or_string, bytes):
+            stream_or_string = stream_or_string.decode()
+        try:
+            objects = json.loads(stream_or_string)
+        except Exception as exc:
+            raise DeserializationError() from exc
+        super().__init__(objects, **options)
+
+    def _handle_object(self, obj):
+        try:
+            yield from super()._handle_object(obj)
+        except (GeneratorExit, DeserializationError):
+            raise
+        except Exception as exc:
+            raise DeserializationError(f"Error deserializing object: {exc}") from exc
 
 
 class DjangoJSONEncoder(json.JSONEncoder):

+ 22 - 11
django/core/serializers/jsonl.py

@@ -39,19 +39,30 @@ class Serializer(PythonSerializer):
         return super(PythonSerializer, self).getvalue()
 
 
-def Deserializer(stream_or_string, **options):
+class Deserializer(PythonDeserializer):
     """Deserialize a stream or string of JSON data."""
-    if isinstance(stream_or_string, bytes):
-        stream_or_string = stream_or_string.decode()
-    if isinstance(stream_or_string, (bytes, str)):
-        stream_or_string = stream_or_string.split("\n")
-
-    for line in stream_or_string:
-        if not line.strip():
-            continue
+
+    def __init__(self, stream_or_string, **options):
+        if isinstance(stream_or_string, bytes):
+            stream_or_string = stream_or_string.decode()
+        if isinstance(stream_or_string, str):
+            stream_or_string = stream_or_string.splitlines()
+        super().__init__(Deserializer._get_lines(stream_or_string), **options)
+
+    def _handle_object(self, obj):
         try:
-            yield from PythonDeserializer([json.loads(line)], **options)
+            yield from super()._handle_object(obj)
         except (GeneratorExit, DeserializationError):
             raise
         except Exception as exc:
-            raise DeserializationError() from exc
+            raise DeserializationError(f"Error deserializing object: {exc}") from exc
+
+    @staticmethod
+    def _get_lines(stream):
+        for line in stream:
+            if not line.strip():
+                continue
+            try:
+                yield json.loads(line)
+            except Exception as exc:
+                raise DeserializationError() from exc

+ 71 - 48
django/core/serializers/python.py

@@ -96,45 +96,60 @@ class Serializer(base.Serializer):
         return self.objects
 
 
-def Deserializer(
-    object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False, **options
-):
+class Deserializer(base.Deserializer):
     """
     Deserialize simple Python objects back into Django ORM instances.
 
     It's expected that you pass the Python objects themselves (instead of a
     stream or a string) to the constructor
     """
-    handle_forward_references = options.pop("handle_forward_references", False)
-    field_names_cache = {}  # Model: <list of field_names>
 
-    for d in object_list:
+    def __init__(
+        self, object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False, **options
+    ):
+        super().__init__(object_list, **options)
+        self.handle_forward_references = options.pop("handle_forward_references", False)
+        self.using = using
+        self.ignorenonexistent = ignorenonexistent
+        self.field_names_cache = {}  # Model: <list of field_names>
+        self._iterator = None
+
+    def __iter__(self):
+        for obj in self.stream:
+            yield from self._handle_object(obj)
+
+    def __next__(self):
+        if self._iterator is None:
+            self._iterator = iter(self)
+        return next(self._iterator)
+
+    def _handle_object(self, obj):
+        data = {}
+        m2m_data = {}
+        deferred_fields = {}
+
         # Look up the model and starting build a dict of data for it.
         try:
-            Model = _get_model(d["model"])
+            Model = self._get_model_from_node(obj["model"])
         except base.DeserializationError:
-            if ignorenonexistent:
-                continue
-            else:
-                raise
-        data = {}
-        if "pk" in d:
+            if self.ignorenonexistent:
+                return
+            raise
+        if "pk" in obj:
             try:
-                data[Model._meta.pk.attname] = Model._meta.pk.to_python(d.get("pk"))
+                data[Model._meta.pk.attname] = Model._meta.pk.to_python(obj.get("pk"))
             except Exception as e:
                 raise base.DeserializationError.WithData(
-                    e, d["model"], d.get("pk"), None
+                    e, obj["model"], obj.get("pk"), None
                 )
-        m2m_data = {}
-        deferred_fields = {}
 
-        if Model not in field_names_cache:
-            field_names_cache[Model] = {f.name for f in Model._meta.get_fields()}
-        field_names = field_names_cache[Model]
+        if Model not in self.field_names_cache:
+            self.field_names_cache[Model] = {f.name for f in Model._meta.get_fields()}
+        field_names = self.field_names_cache[Model]
 
         # Handle each field
-        for field_name, field_value in d["fields"].items():
-            if ignorenonexistent and field_name not in field_names:
+        for field_name, field_value in obj["fields"].items():
+            if self.ignorenonexistent and field_name not in field_names:
                 # skip fields no longer on model
                 continue
 
@@ -145,51 +160,59 @@ def Deserializer(
                 field.remote_field, models.ManyToManyRel
             ):
                 try:
-                    values = base.deserialize_m2m_values(
-                        field, field_value, using, handle_forward_references
-                    )
+                    values = self._handle_m2m_field_node(field, field_value)
+                    if values == base.DEFER_FIELD:
+                        deferred_fields[field] = field_value
+                    else:
+                        m2m_data[field.name] = values
                 except base.M2MDeserializationError as e:
                     raise base.DeserializationError.WithData(
-                        e.original_exc, d["model"], d.get("pk"), e.pk
+                        e.original_exc, obj["model"], obj.get("pk"), e.pk
                     )
-                if values == base.DEFER_FIELD:
-                    deferred_fields[field] = field_value
-                else:
-                    m2m_data[field.name] = values
+
             # Handle FK fields
             elif field.remote_field and isinstance(
                 field.remote_field, models.ManyToOneRel
             ):
                 try:
-                    value = base.deserialize_fk_value(
-                        field, field_value, using, handle_forward_references
-                    )
+                    value = self._handle_fk_field_node(field, field_value)
+                    if value == base.DEFER_FIELD:
+                        deferred_fields[field] = field_value
+                    else:
+                        data[field.attname] = value
                 except Exception as e:
                     raise base.DeserializationError.WithData(
-                        e, d["model"], d.get("pk"), field_value
+                        e, obj["model"], obj.get("pk"), field_value
                     )
-                if value == base.DEFER_FIELD:
-                    deferred_fields[field] = field_value
-                else:
-                    data[field.attname] = value
+
             # Handle all other fields
             else:
                 try:
                     data[field.name] = field.to_python(field_value)
                 except Exception as e:
                     raise base.DeserializationError.WithData(
-                        e, d["model"], d.get("pk"), field_value
+                        e, obj["model"], obj.get("pk"), field_value
                     )
 
-        obj = base.build_instance(Model, data, using)
-        yield base.DeserializedObject(obj, m2m_data, deferred_fields)
+        model_instance = base.build_instance(Model, data, self.using)
+        yield base.DeserializedObject(model_instance, m2m_data, deferred_fields)
 
+    def _handle_m2m_field_node(self, field, field_value):
+        return base.deserialize_m2m_values(
+            field, field_value, self.using, self.handle_forward_references
+        )
 
-def _get_model(model_identifier):
-    """Look up a model from an "app_label.model_name" string."""
-    try:
-        return apps.get_model(model_identifier)
-    except (LookupError, TypeError):
-        raise base.DeserializationError(
-            "Invalid model identifier: '%s'" % model_identifier
+    def _handle_fk_field_node(self, field, field_value):
+        return base.deserialize_fk_value(
+            field, field_value, self.using, self.handle_forward_references
         )
+
+    @staticmethod
+    def _get_model_from_node(model_identifier):
+        """Look up a model from an "app_label.model_name" string."""
+        try:
+            return apps.get_model(model_identifier)
+        except (LookupError, TypeError):
+            raise base.DeserializationError(
+                f"Invalid model identifier: {model_identifier}"
+            )

+ 18 - 13
django/core/serializers/pyyaml.py

@@ -6,7 +6,6 @@ Requires PyYaml (https://pyyaml.org/), but that's checked for in __init__.
 
 import collections
 import decimal
-from io import StringIO
 
 import yaml
 
@@ -66,17 +65,23 @@ class Serializer(PythonSerializer):
         return super(PythonSerializer, self).getvalue()
 
 
-def Deserializer(stream_or_string, **options):
+class Deserializer(PythonDeserializer):
     """Deserialize a stream or string of YAML data."""
-    if isinstance(stream_or_string, bytes):
-        stream_or_string = stream_or_string.decode()
-    if isinstance(stream_or_string, str):
-        stream = StringIO(stream_or_string)
-    else:
+
+    def __init__(self, stream_or_string, **options):
         stream = stream_or_string
-    try:
-        yield from PythonDeserializer(yaml.load(stream, Loader=SafeLoader), **options)
-    except (GeneratorExit, DeserializationError):
-        raise
-    except Exception as exc:
-        raise DeserializationError() from exc
+        if isinstance(stream_or_string, bytes):
+            stream = stream_or_string.decode()
+        try:
+            objects = yaml.load(stream, Loader=SafeLoader)
+        except Exception as exc:
+            raise DeserializationError() from exc
+        super().__init__(objects, **options)
+
+    def _handle_object(self, obj):
+        try:
+            yield from super()._handle_object(obj)
+        except (GeneratorExit, DeserializationError):
+            raise
+        except Exception as exc:
+            raise DeserializationError(f"Error deserializing object: {exc}") from exc

+ 3 - 1
docs/releases/5.2.txt

@@ -241,7 +241,9 @@ Security
 Serialization
 ~~~~~~~~~~~~~
 
-* ...
+* Each serialization format now defines a ``Deserializer`` class, rather than a
+  function, to improve extensibility when defining a
+  :ref:`custom serialization format <custom-serialization-formats>`.
 
 Signals
 ~~~~~~~

+ 80 - 0
docs/topics/serialization.txt

@@ -347,6 +347,86 @@ again a mapping with the key being name of the field and the value the value:
 
 Referential fields are again represented by the PK or sequence of PKs.
 
+.. _custom-serialization-formats:
+
+Custom serialization formats
+----------------------------
+
+In addition to the default formats, you can create a custom serialization
+format.
+
+For example, let’s consider a csv serializer and deserializer. First, define a
+``Serializer`` and a ``Deserializer`` class. These can override existing
+serialization format classes:
+
+.. code-block:: python
+   :caption: ``path/to/custom_csv_serializer.py``
+
+    import csv
+
+    from django.apps import apps
+    from django.core import serializers
+    from django.core.serializers.base import DeserializationError
+
+
+    class Serializer(serializers.python.Serializer):
+        def get_dump_object(self, obj):
+            dumped_object = super().get_dump_object(obj)
+            row = [dumped_object["model"], str(dumped_object["pk"])]
+            row += [str(value) for value in dumped_object["fields"].values()]
+            return ",".join(row), dumped_object["model"]
+
+        def end_object(self, obj):
+            dumped_object_str, model = self.get_dump_object(obj)
+            if self.first:
+                fields = [field.name for field in apps.get_model(model)._meta.fields]
+                header = ",".join(fields)
+                self.stream.write(f"model,{header}\n")
+            self.stream.write(f"{dumped_object_str}\n")
+
+        def getvalue(self):
+            return super(serializers.python.Serializer, self).getvalue()
+
+
+    class Deserializer(serializers.python.Deserializer):
+        def __init__(self, stream_or_string, **options):
+            if isinstance(stream_or_string, bytes):
+                stream_or_string = stream_or_string.decode()
+            if isinstance(stream_or_string, str):
+                stream_or_string = stream_or_string.splitlines()
+            try:
+                objects = csv.DictReader(stream_or_string)
+            except Exception as exc:
+                raise DeserializationError() from exc
+            super().__init__(objects, **options)
+
+        def _handle_object(self, obj):
+            try:
+                model_fields = apps.get_model(obj["model"])._meta.fields
+                obj["fields"] = {
+                    field.name: obj[field.name]
+                    for field in model_fields
+                    if field.name in obj
+                }
+                yield from super()._handle_object(obj)
+            except (GeneratorExit, DeserializationError):
+                raise
+            except Exception as exc:
+                raise DeserializationError(f"Error deserializing object: {exc}") from exc
+
+Then add the module containing the serializer definitions to your
+:setting:`SERIALIZATION_MODULES` setting::
+
+    SERIALIZATION_MODULES = {
+        "csv": "path.to.custom_csv_serializer",
+        "json": "django.core.serializers.json",
+    }
+
+.. versionchanged:: 5.2
+
+    A ``Deserializer`` class definition was added to each of the provided
+    serialization formats.
+
 .. _topics-serialization-natural-keys:
 
 Natural keys

+ 125 - 0
tests/serializers/test_deserialization.py

@@ -0,0 +1,125 @@
+import json
+
+from django.core.serializers.base import DeserializationError, DeserializedObject
+from django.core.serializers.json import Deserializer as JsonDeserializer
+from django.core.serializers.jsonl import Deserializer as JsonlDeserializer
+from django.core.serializers.python import Deserializer
+from django.core.serializers.pyyaml import Deserializer as YamlDeserializer
+from django.test import SimpleTestCase
+
+from .models import Author
+
+
+class TestDeserializer(SimpleTestCase):
+    def setUp(self):
+        self.object_list = [
+            {"pk": 1, "model": "serializers.author", "fields": {"name": "Jane"}},
+            {"pk": 2, "model": "serializers.author", "fields": {"name": "Joe"}},
+        ]
+        self.deserializer = Deserializer(self.object_list)
+        self.jane = Author(name="Jane", pk=1)
+        self.joe = Author(name="Joe", pk=2)
+
+    def test_deserialized_object_repr(self):
+        deserial_obj = DeserializedObject(obj=self.jane)
+        self.assertEqual(
+            repr(deserial_obj), "<DeserializedObject: serializers.Author(pk=1)>"
+        )
+
+    def test_next_functionality(self):
+        first_item = next(self.deserializer)
+
+        self.assertEqual(first_item.object, self.jane)
+
+        second_item = next(self.deserializer)
+        self.assertEqual(second_item.object, self.joe)
+
+        with self.assertRaises(StopIteration):
+            next(self.deserializer)
+
+    def test_invalid_model_identifier(self):
+        invalid_object_list = [
+            {"pk": 1, "model": "serializers.author2", "fields": {"name": "Jane"}}
+        ]
+        self.deserializer = Deserializer(invalid_object_list)
+        with self.assertRaises(DeserializationError):
+            next(self.deserializer)
+
+        deserializer = Deserializer(object_list=[])
+        with self.assertRaises(StopIteration):
+            next(deserializer)
+
+    def test_custom_deserializer(self):
+        class CustomDeserializer(Deserializer):
+            @staticmethod
+            def _get_model_from_node(model_identifier):
+                return Author
+
+        deserializer = CustomDeserializer(self.object_list)
+        result = next(iter(deserializer))
+        deserialized_object = result.object
+        self.assertEqual(
+            self.jane,
+            deserialized_object,
+        )
+
+    def test_empty_object_list(self):
+        deserializer = Deserializer(object_list=[])
+        with self.assertRaises(StopIteration):
+            next(deserializer)
+
+    def test_json_bytes_input(self):
+        test_string = json.dumps(self.object_list)
+        stream = test_string.encode("utf-8")
+        deserializer = JsonDeserializer(stream_or_string=stream)
+
+        first_item = next(deserializer)
+        second_item = next(deserializer)
+
+        self.assertEqual(first_item.object, self.jane)
+        self.assertEqual(second_item.object, self.joe)
+
+    def test_jsonl_bytes_input(self):
+        test_string = """
+        {"pk": 1, "model": "serializers.author", "fields": {"name": "Jane"}}
+        {"pk": 2, "model": "serializers.author", "fields": {"name": "Joe"}}
+        {"pk": 3, "model": "serializers.author", "fields": {"name": "John"}}
+        {"pk": 4, "model": "serializers.author", "fields": {"name": "Smith"}}"""
+        stream = test_string.encode("utf-8")
+        deserializer = JsonlDeserializer(stream_or_string=stream)
+
+        first_item = next(deserializer)
+        second_item = next(deserializer)
+
+        self.assertEqual(first_item.object, self.jane)
+        self.assertEqual(second_item.object, self.joe)
+
+    def test_yaml_bytes_input(self):
+        test_string = """- pk: 1
+  model: serializers.author
+  fields:
+    name: Jane
+
+- pk: 2
+  model: serializers.author
+  fields:
+    name: Joe
+
+- pk: 3
+  model: serializers.author
+  fields:
+    name: John
+
+- pk: 4
+  model: serializers.author
+  fields:
+    name: Smith
+"""
+        stream = test_string.encode("utf-8")
+        deserializer = YamlDeserializer(stream_or_string=stream)
+
+        first_item = next(deserializer)
+        second_item = next(deserializer)
+
+        self.assertEqual(first_item.object, self.jane)
+        self.assertEqual(second_item.object, self.joe)

+ 0 - 13
tests/serializers/test_deserializedobject.py

@@ -1,13 +0,0 @@
-from django.core.serializers.base import DeserializedObject
-from django.test import SimpleTestCase
-
-from .models import Author
-
-
-class TestDeserializedObjectTests(SimpleTestCase):
-    def test_repr(self):
-        author = Author(name="John", pk=1)
-        deserial_obj = DeserializedObject(obj=author)
-        self.assertEqual(
-            repr(deserial_obj), "<DeserializedObject: serializers.Author(pk=1)>"
-        )