Przeglądaj źródła

Fixed #24093 -- Prevented MigrationWriter to write operation kwargs that are not explicitly deconstructed

Markus Holtermann 10 lat temu
rodzic
commit
862ea825b5

+ 39 - 25
django/db/migrations/writer.py

@@ -46,29 +46,14 @@ class OperationWriter(object):
         self.buff = []
 
     def serialize(self):
-        imports = set()
-        name, args, kwargs = self.operation.deconstruct()
-        argspec = inspect.getargspec(self.operation.__init__)
-        normalized_kwargs = inspect.getcallargs(self.operation.__init__, *args, **kwargs)
-
-        # See if this operation is in django.db.migrations. If it is,
-        # We can just use the fact we already have that imported,
-        # otherwise, we need to add an import for the operation class.
-        if getattr(migrations, name, None) == self.operation.__class__:
-            self.feed('migrations.%s(' % name)
-        else:
-            imports.add('import %s' % (self.operation.__class__.__module__))
-            self.feed('%s.%s(' % (self.operation.__class__.__module__, name))
 
-        self.indent()
-        for arg_name in argspec.args[1:]:
-            arg_value = normalized_kwargs[arg_name]
-            if (arg_name in self.operation.serialization_expand_args and
-                    isinstance(arg_value, (list, tuple, dict))):
-                if isinstance(arg_value, dict):
-                    self.feed('%s={' % arg_name)
+        def _write(_arg_name, _arg_value):
+            if (_arg_name in self.operation.serialization_expand_args and
+                    isinstance(_arg_value, (list, tuple, dict))):
+                if isinstance(_arg_value, dict):
+                    self.feed('%s={' % _arg_name)
                     self.indent()
-                    for key, value in arg_value.items():
+                    for key, value in _arg_value.items():
                         key_string, key_imports = MigrationWriter.serialize(key)
                         arg_string, arg_imports = MigrationWriter.serialize(value)
                         self.feed('%s: %s,' % (key_string, arg_string))
@@ -77,18 +62,47 @@ class OperationWriter(object):
                     self.unindent()
                     self.feed('},')
                 else:
-                    self.feed('%s=[' % arg_name)
+                    self.feed('%s=[' % _arg_name)
                     self.indent()
-                    for item in arg_value:
+                    for item in _arg_value:
                         arg_string, arg_imports = MigrationWriter.serialize(item)
                         self.feed('%s,' % arg_string)
                         imports.update(arg_imports)
                     self.unindent()
                     self.feed('],')
             else:
-                arg_string, arg_imports = MigrationWriter.serialize(arg_value)
-                self.feed('%s=%s,' % (arg_name, arg_string))
+                arg_string, arg_imports = MigrationWriter.serialize(_arg_value)
+                self.feed('%s=%s,' % (_arg_name, arg_string))
                 imports.update(arg_imports)
+
+        imports = set()
+        name, args, kwargs = self.operation.deconstruct()
+        argspec = inspect.getargspec(self.operation.__init__)
+
+        # See if this operation is in django.db.migrations. If it is,
+        # We can just use the fact we already have that imported,
+        # otherwise, we need to add an import for the operation class.
+        if getattr(migrations, name, None) == self.operation.__class__:
+            self.feed('migrations.%s(' % name)
+        else:
+            imports.add('import %s' % (self.operation.__class__.__module__))
+            self.feed('%s.%s(' % (self.operation.__class__.__module__, name))
+
+        self.indent()
+
+        # Start at one because argspec includes "self"
+        for i, arg in enumerate(args, 1):
+            arg_value = arg
+            arg_name = argspec.args[i]
+            _write(arg_name, arg_value)
+
+        i = len(args)
+        # Only iterate over remaining arguments
+        for arg_name in argspec.args[i + 1:]:
+            if arg_name in kwargs:
+                arg_value = kwargs[arg_name]
+                _write(arg_name, arg_value)
+
         self.unindent()
         self.feed('),')
         return self.render(), imports

+ 61 - 0
tests/custom_migration_operations/operations.py

@@ -31,3 +31,64 @@ class TestOperation(Operation):
 
 class CreateModel(TestOperation):
     pass
+
+
+class ArgsOperation(TestOperation):
+    def __init__(self, arg1, arg2):
+        self.arg1, self.arg2 = arg1, arg2
+
+    def deconstruct(self):
+        return (
+            self.__class__.__name__,
+            [self.arg1, self.arg2],
+            {}
+        )
+
+
+class KwargsOperation(TestOperation):
+    def __init__(self, kwarg1=None, kwarg2=None):
+        self.kwarg1, self.kwarg2 = kwarg1, kwarg2
+
+    def deconstruct(self):
+        kwargs = {}
+        if self.kwarg1 is not None:
+            kwargs['kwarg1'] = self.kwarg1
+        if self.kwarg2 is not None:
+            kwargs['kwarg2'] = self.kwarg2
+        return (
+            self.__class__.__name__,
+            [],
+            kwargs
+        )
+
+
+class ArgsKwargsOperation(TestOperation):
+    def __init__(self, arg1, arg2, kwarg1=None, kwarg2=None):
+        self.arg1, self.arg2 = arg1, arg2
+        self.kwarg1, self.kwarg2 = kwarg1, kwarg2
+
+    def deconstruct(self):
+        kwargs = {}
+        if self.kwarg1 is not None:
+            kwargs['kwarg1'] = self.kwarg1
+        if self.kwarg2 is not None:
+            kwargs['kwarg2'] = self.kwarg2
+        return (
+            self.__class__.__name__,
+            [self.arg1, self.arg2],
+            kwargs,
+        )
+
+
+class ExpandArgsOperation(TestOperation):
+    serialization_expand_args = ['arg']
+
+    def __init__(self, arg):
+        self.arg = arg
+
+    def deconstruct(self):
+        return (
+            self.__class__.__name__,
+            [self.arg],
+            {}
+        )

+ 75 - 2
tests/migrations/test_writer.py

@@ -10,8 +10,8 @@ import unittest
 
 from django.core.validators import RegexValidator, EmailValidator
 from django.db import models, migrations
-from django.db.migrations.writer import MigrationWriter, SettingsReference
-from django.test import TestCase, ignore_warnings
+from django.db.migrations.writer import MigrationWriter, OperationWriter, SettingsReference
+from django.test import SimpleTestCase, TestCase, ignore_warnings
 from django.conf import settings
 from django.utils import datetime_safe, six
 from django.utils.deconstruct import deconstructible
@@ -30,6 +30,79 @@ class TestModel1(object):
     thing = models.FileField(upload_to=upload_to)
 
 
+class OperationWriterTests(SimpleTestCase):
+
+    def test_empty_signature(self):
+        operation = custom_migration_operations.operations.TestOperation()
+        writer = OperationWriter(operation)
+        writer.indentation = 0
+        buff, imports = writer.serialize()
+        self.assertEqual(imports, {'import custom_migration_operations.operations'})
+        self.assertEqual(
+            buff,
+            'custom_migration_operations.operations.TestOperation(\n'
+            '),'
+        )
+
+    def test_args_signature(self):
+        operation = custom_migration_operations.operations.ArgsOperation(1, 2)
+        writer = OperationWriter(operation)
+        writer.indentation = 0
+        buff, imports = writer.serialize()
+        self.assertEqual(imports, {'import custom_migration_operations.operations'})
+        self.assertEqual(
+            buff,
+            'custom_migration_operations.operations.ArgsOperation(\n'
+            '    arg1=1,\n'
+            '    arg2=2,\n'
+            '),'
+        )
+
+    def test_kwargs_signature(self):
+        operation = custom_migration_operations.operations.KwargsOperation(kwarg1=1)
+        writer = OperationWriter(operation)
+        writer.indentation = 0
+        buff, imports = writer.serialize()
+        self.assertEqual(imports, {'import custom_migration_operations.operations'})
+        self.assertEqual(
+            buff,
+            'custom_migration_operations.operations.KwargsOperation(\n'
+            '    kwarg1=1,\n'
+            '),'
+        )
+
+    def test_args_kwargs_signature(self):
+        operation = custom_migration_operations.operations.ArgsKwargsOperation(1, 2, kwarg2=4)
+        writer = OperationWriter(operation)
+        writer.indentation = 0
+        buff, imports = writer.serialize()
+        self.assertEqual(imports, {'import custom_migration_operations.operations'})
+        self.assertEqual(
+            buff,
+            'custom_migration_operations.operations.ArgsKwargsOperation(\n'
+            '    arg1=1,\n'
+            '    arg2=2,\n'
+            '    kwarg2=4,\n'
+            '),'
+        )
+
+    def test_expand_args_signature(self):
+        operation = custom_migration_operations.operations.ExpandArgsOperation([1, 2])
+        writer = OperationWriter(operation)
+        writer.indentation = 0
+        buff, imports = writer.serialize()
+        self.assertEqual(imports, {'import custom_migration_operations.operations'})
+        self.assertEqual(
+            buff,
+            'custom_migration_operations.operations.ExpandArgsOperation(\n'
+            '    arg=[\n'
+            '        1,\n'
+            '        2,\n'
+            '    ],\n'
+            '),'
+        )
+
+
 class WriterTests(TestCase):
     """
     Tests the migration writer (makes migration files from Migration instances)