Browse Source

Fixed #28275 -- Added more hooks to SchemaEditor._alter_field().

Florian Apolloner 7 years ago
parent
commit
823d73be3e

+ 66 - 56
django/db/backends/base/schema.py

@@ -407,14 +407,12 @@ class BaseDatabaseSchemaEditor:
         # Drop the default if we need to
         # (Django usually does not use in-database defaults)
         if not self.skip_default(field) and self.effective_default(field) is not None:
+            changes_sql, params = self._alter_column_default_sql(model, None, field, drop=True)
             sql = self.sql_alter_column % {
                 "table": self.quote_name(model._meta.db_table),
-                "changes": self.sql_alter_column_no_default % {
-                    "column": self.quote_name(field.column),
-                    "type": db_params['type'],
-                }
+                "changes": changes_sql,
             }
-            self.execute(sql)
+            self.execute(sql, params)
         # Add an index, if required
         self.deferred_sql.extend(self._field_indexes_sql(model, field))
         # Add any FK constraints later
@@ -573,9 +571,7 @@ class BaseDatabaseSchemaEditor:
         post_actions = []
         # Type change?
         if old_type != new_type:
-            fragment, other_actions = self._alter_column_type_sql(
-                model._meta.db_table, old_field, new_field, new_type
-            )
+            fragment, other_actions = self._alter_column_type_sql(model, old_field, new_field, new_type)
             actions.append(fragment)
             post_actions.extend(other_actions)
         # When changing a column NULL constraint to NOT NULL with a given
@@ -595,49 +591,12 @@ class BaseDatabaseSchemaEditor:
             not self.skip_default(new_field)
         )
         if needs_database_default:
-            if self.connection.features.requires_literal_defaults:
-                # Some databases can't take defaults as a parameter (oracle)
-                # If this is the case, the individual schema backend should
-                # implement prepare_default
-                actions.append((
-                    self.sql_alter_column_default % {
-                        "column": self.quote_name(new_field.column),
-                        "type": new_type,
-                        "default": self.prepare_default(new_default),
-                    },
-                    [],
-                ))
-            else:
-                actions.append((
-                    self.sql_alter_column_default % {
-                        "column": self.quote_name(new_field.column),
-                        "type": new_type,
-                        "default": "%s",
-                    },
-                    [new_default],
-                ))
+            actions.append(self._alter_column_default_sql(model, old_field, new_field))
         # Nullability change?
         if old_field.null != new_field.null:
-            if (self.connection.features.interprets_empty_strings_as_nulls and
-                    new_field.get_internal_type() in ("CharField", "TextField")):
-                # The field is nullable in the database anyway, leave it alone
-                pass
-            elif new_field.null:
-                null_actions.append((
-                    self.sql_alter_column_null % {
-                        "column": self.quote_name(new_field.column),
-                        "type": new_type,
-                    },
-                    [],
-                ))
-            else:
-                null_actions.append((
-                    self.sql_alter_column_not_null % {
-                        "column": self.quote_name(new_field.column),
-                        "type": new_type,
-                    },
-                    [],
-                ))
+            fragment = self._alter_column_null_sql(model, old_field, new_field)
+            if fragment:
+                null_actions.append(fragment)
         # Only if we have a default and there is a change from NULL to NOT NULL
         four_way_default_alteration = (
             new_field.has_default() and
@@ -726,7 +685,7 @@ class BaseDatabaseSchemaEditor:
             rel_db_params = new_rel.field.db_parameters(connection=self.connection)
             rel_type = rel_db_params['type']
             fragment, other_actions = self._alter_column_type_sql(
-                new_rel.related_model._meta.db_table, old_rel.field, new_rel.field, rel_type
+                new_rel.related_model, old_rel.field, new_rel.field, rel_type
             )
             self.execute(
                 self.sql_alter_column % {
@@ -760,19 +719,70 @@ class BaseDatabaseSchemaEditor:
         # Drop the default if we need to
         # (Django usually does not use in-database defaults)
         if needs_database_default:
+            changes_sql, params = self._alter_column_default_sql(model, old_field, new_field, drop=True)
             sql = self.sql_alter_column % {
                 "table": self.quote_name(model._meta.db_table),
-                "changes": self.sql_alter_column_no_default % {
-                    "column": self.quote_name(new_field.column),
-                    "type": new_type,
-                }
+                "changes": changes_sql,
             }
-            self.execute(sql)
+            self.execute(sql, params)
         # Reset connection if required
         if self.connection.features.connection_persists_old_columns:
             self.connection.close()
 
-    def _alter_column_type_sql(self, table, old_field, new_field, new_type):
+    def _alter_column_null_sql(self, model, old_field, new_field):
+        """
+        Hook to specialize column null alteration.
+
+        Return a (sql, params) fragment to set a column to null or non-null
+        as required by new_field, or None if no changes are required.
+        """
+        if (self.connection.features.interprets_empty_strings_as_nulls and
+                new_field.get_internal_type() in ("CharField", "TextField")):
+            # The field is nullable in the database anyway, leave it alone.
+            return
+        else:
+            new_db_params = new_field.db_parameters(connection=self.connection)
+            sql = self.sql_alter_column_null if new_field.null else self.sql_alter_column_not_null
+            return (
+                sql % {
+                    'column': self.quote_name(new_field.column),
+                    'type': new_db_params['type'],
+                },
+                [],
+            )
+
+    def _alter_column_default_sql(self, model, old_field, new_field, drop=False):
+        """
+        Hook to specialize column default alteration.
+
+        Return a (sql, params) fragment to add or drop (depending on the drop
+        argument) a default to new_field's column.
+        """
+        new_default = self.effective_default(new_field)
+        default = '%s'
+        params = [new_default]
+
+        if drop:
+            params = []
+        elif self.connection.features.requires_literal_defaults:
+            # Some databases (Oracle) can't take defaults as a parameter
+            # If this is the case, the SchemaEditor for that database should
+            # implement prepare_default().
+            default = self.prepare_default(new_default)
+            params = []
+
+        new_db_params = new_field.db_parameters(connection=self.connection)
+        sql = self.sql_alter_column_no_default if drop else self.sql_alter_column_default
+        return (
+            sql % {
+                'column': self.quote_name(new_field.column),
+                'type': new_db_params['type'],
+                'default': default,
+            },
+            params,
+        )
+
+    def _alter_column_type_sql(self, model, old_field, new_field, new_type):
         """
         Hook to specialize column type alteration for different backends,
         for cases when a creation type is different to an alteration type

+ 2 - 2
django/db/backends/mysql/schema.py

@@ -92,9 +92,9 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
             new_type += " NOT NULL"
         return new_type
 
-    def _alter_column_type_sql(self, table, old_field, new_field, new_type):
+    def _alter_column_type_sql(self, model, old_field, new_field, new_type):
         new_type = self._set_field_new_type_null_status(old_field, new_type)
-        return super()._alter_column_type_sql(table, old_field, new_field, new_type)
+        return super()._alter_column_type_sql(model, old_field, new_field, new_type)
 
     def _rename_field_sql(self, table, old_field, new_field, new_type):
         new_type = self._set_field_new_type_null_status(old_field, new_type)

+ 3 - 2
django/db/backends/postgresql/schema.py

@@ -52,8 +52,9 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
                 return self._create_index_sql(model, [field], suffix='_like', sql=self.sql_create_text_index)
         return None
 
-    def _alter_column_type_sql(self, table, old_field, new_field, new_type):
+    def _alter_column_type_sql(self, model, old_field, new_field, new_type):
         """Make ALTER TYPE with SERIAL make sense."""
+        table = model._meta.db_table
         if new_type.lower() in ("serial", "bigserial"):
             column = new_field.column
             sequence_name = "%s_%s_seq" % (table, column)
@@ -100,7 +101,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
                 ],
             )
         else:
-            return super()._alter_column_type_sql(table, old_field, new_field, new_type)
+            return super()._alter_column_type_sql(model, old_field, new_field, new_type)
 
     def _alter_field(self, model, old_field, new_field, old_type, new_type,
                      old_db_params, new_db_params, strict=False):

+ 3 - 0
docs/releases/2.0.txt

@@ -289,6 +289,9 @@ Database backend API
   attribute with the name of the database that your backend works with. Django
   may use it in various messages, such as in system checks.
 
+* The first argument of ``SchemaEditor._alter_column_type_sql()`` is now
+  ``model`` rather than ``table``.
+
 * To improve performance when streaming large result sets from the database,
   :meth:`.QuerySet.iterator` now fetches 2000 rows at a time instead of 100.
   The old behavior can be restored using the ``chunk_size`` parameter. For