Преглед изворни кода

Fixed #24509 -- Added Expression support to SQLInsertCompiler

Alex Hill пре 9 година
родитељ
комит
134ca4d438

+ 5 - 6
django/contrib/gis/db/backends/oracle/operations.py

@@ -263,11 +263,10 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
         from django.contrib.gis.db.backends.oracle.models import OracleSpatialRefSys
         return OracleSpatialRefSys
 
-    def modify_insert_params(self, placeholders, params):
+    def modify_insert_params(self, placeholder, params):
         """Drop out insert parameters for NULL placeholder. Needed for Oracle Spatial
-        backend due to #10888
+        backend due to #10888.
         """
-        # This code doesn't work for bulk insert cases.
-        assert len(placeholders) == 1
-        return [[param for pholder, param
-                 in six.moves.zip(placeholders[0], params[0]) if pholder != 'NULL'], ]
+        if placeholder == 'NULL':
+            return []
+        return super(OracleOperations, self).modify_insert_params(placeholder, params)

+ 1 - 1
django/db/backends/base/operations.py

@@ -576,7 +576,7 @@ class BaseDatabaseOperations(object):
     def combine_duration_expression(self, connector, sub_expressions):
         return self.combine_expression(connector, sub_expressions)
 
-    def modify_insert_params(self, placeholders, params):
+    def modify_insert_params(self, placeholder, params):
         """Allow modification of insert parameters. Needed for Oracle Spatial
         backend due to #10888.
         """

+ 4 - 3
django/db/backends/mysql/operations.py

@@ -166,9 +166,10 @@ class DatabaseOperations(BaseDatabaseOperations):
     def max_name_length(self):
         return 64
 
-    def bulk_insert_sql(self, fields, num_values):
-        items_sql = "(%s)" % ", ".join(["%s"] * len(fields))
-        return "VALUES " + ", ".join([items_sql] * num_values)
+    def bulk_insert_sql(self, fields, placeholder_rows):
+        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
+        values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
+        return "VALUES " + values_sql
 
     def combine_expression(self, connector, sub_expressions):
         """

+ 5 - 3
django/db/backends/oracle/operations.py

@@ -439,6 +439,8 @@ WHEN (new.%(col_name)s IS NULL)
         name_length = self.max_name_length() - 3
         return '%s_TR' % truncate_name(table, name_length).upper()
 
-    def bulk_insert_sql(self, fields, num_values):
-        items_sql = "SELECT %s FROM DUAL" % ", ".join(["%s"] * len(fields))
-        return " UNION ALL ".join([items_sql] * num_values)
+    def bulk_insert_sql(self, fields, placeholder_rows):
+        return " UNION ALL ".join(
+            "SELECT %s FROM DUAL" % ", ".join(row)
+            for row in placeholder_rows
+        )

+ 4 - 3
django/db/backends/postgresql/operations.py

@@ -221,9 +221,10 @@ class DatabaseOperations(BaseDatabaseOperations):
     def return_insert_id(self):
         return "RETURNING %s", ()
 
-    def bulk_insert_sql(self, fields, num_values):
-        items_sql = "(%s)" % ", ".join(["%s"] * len(fields))
-        return "VALUES " + ", ".join([items_sql] * num_values)
+    def bulk_insert_sql(self, fields, placeholder_rows):
+        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
+        values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
+        return "VALUES " + values_sql
 
     def adapt_datefield_value(self, value):
         return value

+ 5 - 7
django/db/backends/sqlite3/operations.py

@@ -226,13 +226,11 @@ class DatabaseOperations(BaseDatabaseOperations):
             value = uuid.UUID(value)
         return value
 
-    def bulk_insert_sql(self, fields, num_values):
-        res = []
-        res.append("SELECT %s" % ", ".join(
-            "%%s AS %s" % self.quote_name(f.column) for f in fields
-        ))
-        res.extend(["UNION ALL SELECT %s" % ", ".join(["%s"] * len(fields))] * (num_values - 1))
-        return " ".join(res)
+    def bulk_insert_sql(self, fields, placeholder_rows):
+        return " UNION ALL ".join(
+            "SELECT %s" % ", ".join(row)
+            for row in placeholder_rows
+        )
 
     def combine_expression(self, connector, sub_expressions):
         # SQLite doesn't have a power function, so we fake it with a

+ 21 - 0
django/db/models/expressions.py

@@ -180,6 +180,13 @@ class BaseExpression(object):
                 return True
         return False
 
+    @cached_property
+    def contains_column_references(self):
+        for expr in self.get_source_expressions():
+            if expr and expr.contains_column_references:
+                return True
+        return False
+
     def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
         """
         Provides the chance to do any preprocessing or validation before being
@@ -339,6 +346,17 @@ class BaseExpression(object):
     def reverse_ordering(self):
         return self
 
+    def flatten(self):
+        """
+        Recursively yield this expression and all subexpressions, in
+        depth-first order.
+        """
+        yield self
+        for expr in self.get_source_expressions():
+            if expr:
+                for inner_expr in expr.flatten():
+                    yield inner_expr
+
 
 class Expression(BaseExpression, Combinable):
     """
@@ -613,6 +631,9 @@ class Random(Expression):
 
 
 class Col(Expression):
+
+    contains_column_references = True
+
     def __init__(self, alias, target, output_field=None):
         if output_field is None:
             output_field = target

+ 2 - 0
django/db/models/query.py

@@ -458,6 +458,8 @@ class QuerySet(object):
         specifying whether an object was created.
         """
         lookup, params = self._extract_model_params(defaults, **kwargs)
+        # The get() needs to be targeted at the write database in order
+        # to avoid potential transaction consistency problems.
         self._for_write = True
         try:
             return self.get(**lookup), False

+ 112 - 33
django/db/models/sql/compiler.py

@@ -909,17 +909,102 @@ class SQLInsertCompiler(SQLCompiler):
         self.return_id = False
         super(SQLInsertCompiler, self).__init__(*args, **kwargs)
 
-    def placeholder(self, field, val):
+    def field_as_sql(self, field, val):
+        """
+        Take a field and a value intended to be saved on that field, and
+        return placeholder SQL and accompanying params. Checks for raw values,
+        expressions and fields with get_placeholder() defined in that order.
+
+        When field is None, the value is considered raw and is used as the
+        placeholder, with no corresponding parameters returned.
+        """
         if field is None:
             # A field value of None means the value is raw.
-            return val
+            sql, params = val, []
+        elif hasattr(val, 'as_sql'):
+            # This is an expression, let's compile it.
+            sql, params = self.compile(val)
         elif hasattr(field, 'get_placeholder'):
             # Some fields (e.g. geo fields) need special munging before
             # they can be inserted.
-            return field.get_placeholder(val, self, self.connection)
+            sql, params = field.get_placeholder(val, self, self.connection), [val]
         else:
             # Return the common case for the placeholder
-            return '%s'
+            sql, params = '%s', [val]
+
+        # The following hook is only used by Oracle Spatial, which sometimes
+        # needs to yield 'NULL' and [] as its placeholder and params instead
+        # of '%s' and [None]. The 'NULL' placeholder is produced earlier by
+        # OracleOperations.get_geom_placeholder(). The following line removes
+        # the corresponding None parameter. See ticket #10888.
+        params = self.connection.ops.modify_insert_params(sql, params)
+
+        return sql, params
+
+    def prepare_value(self, field, value):
+        """
+        Prepare a value to be used in a query by resolving it if it is an
+        expression and otherwise calling the field's get_db_prep_save().
+        """
+        if hasattr(value, 'resolve_expression'):
+            value = value.resolve_expression(self.query, allow_joins=False, for_save=True)
+            # Don't allow values containing Col expressions. They refer to
+            # existing columns on a row, but in the case of insert the row
+            # doesn't exist yet.
+            if value.contains_column_references:
+                raise ValueError(
+                    'Failed to insert expression "%s" on %s. F() expressions '
+                    'can only be used to update, not to insert.' % (value, field)
+                )
+            if value.contains_aggregate:
+                raise FieldError("Aggregate functions are not allowed in this query")
+        else:
+            value = field.get_db_prep_save(value, connection=self.connection)
+        return value
+
+    def pre_save_val(self, field, obj):
+        """
+        Get the given field's value off the given obj. pre_save() is used for
+        things like auto_now on DateTimeField. Skip it if this is a raw query.
+        """
+        if self.query.raw:
+            return getattr(obj, field.attname)
+        return field.pre_save(obj, add=True)
+
+    def assemble_as_sql(self, fields, value_rows):
+        """
+        Take a sequence of N fields and a sequence of M rows of values,
+        generate placeholder SQL and parameters for each field and value, and
+        return a pair containing:
+         * a sequence of M rows of N SQL placeholder strings, and
+         * a sequence of M rows of corresponding parameter values.
+
+        Each placeholder string may contain any number of '%s' interpolation
+        strings, and each parameter row will contain exactly as many params
+        as the total number of '%s's in the corresponding placeholder row.
+        """
+        if not value_rows:
+            return [], []
+
+        # list of (sql, [params]) tuples for each object to be saved
+        # Shape: [n_objs][n_fields][2]
+        rows_of_fields_as_sql = (
+            (self.field_as_sql(field, v) for field, v in zip(fields, row))
+            for row in value_rows
+        )
+
+        # tuple like ([sqls], [[params]s]) for each object to be saved
+        # Shape: [n_objs][2][n_fields]
+        sql_and_param_pair_rows = (zip(*row) for row in rows_of_fields_as_sql)
+
+        # Extract separate lists for placeholders and params.
+        # Each of these has shape [n_objs][n_fields]
+        placeholder_rows, param_rows = zip(*sql_and_param_pair_rows)
+
+        # Params for each field are still lists, and need to be flattened.
+        param_rows = [[p for ps in row for p in ps] for row in param_rows]
+
+        return placeholder_rows, param_rows
 
     def as_sql(self):
         # We don't need quote_name_unless_alias() here, since these are all
@@ -933,35 +1018,27 @@ class SQLInsertCompiler(SQLCompiler):
         result.append('(%s)' % ', '.join(qn(f.column) for f in fields))
 
         if has_fields:
-            params = values = [
-                [
-                    f.get_db_prep_save(
-                        getattr(obj, f.attname) if self.query.raw else f.pre_save(obj, True),
-                        connection=self.connection
-                    ) for f in fields
-                ]
+            value_rows = [
+                [self.prepare_value(field, self.pre_save_val(field, obj)) for field in fields]
                 for obj in self.query.objs
             ]
         else:
-            values = [[self.connection.ops.pk_default_value()] for obj in self.query.objs]
-            params = [[]]
+            # An empty object.
+            value_rows = [[self.connection.ops.pk_default_value()] for _ in self.query.objs]
             fields = [None]
-        can_bulk = (not any(hasattr(field, "get_placeholder") for field in fields) and
-            not self.return_id and self.connection.features.has_bulk_insert)
 
-        if can_bulk:
-            placeholders = [["%s"] * len(fields)]
-        else:
-            placeholders = [
-                [self.placeholder(field, v) for field, v in zip(fields, val)]
-                for val in values
-            ]
-            # Oracle Spatial needs to remove some values due to #10888
-            params = self.connection.ops.modify_insert_params(placeholders, params)
+        # Currently the backends just accept values when generating bulk
+        # queries and generate their own placeholders. Doing that isn't
+        # necessary and it should be possible to use placeholders and
+        # expressions in bulk inserts too.
+        can_bulk = (not self.return_id and self.connection.features.has_bulk_insert)
+
+        placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)
+
         if self.return_id and self.connection.features.can_return_id_from_insert:
-            params = params[0]
+            params = param_rows[0]
             col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
-            result.append("VALUES (%s)" % ", ".join(placeholders[0]))
+            result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
             r_fmt, r_params = self.connection.ops.return_insert_id()
             # Skip empty r_fmt to allow subclasses to customize behavior for
             # 3rd party backends. Refs #19096.
@@ -969,13 +1046,14 @@ class SQLInsertCompiler(SQLCompiler):
                 result.append(r_fmt % col)
                 params += r_params
             return [(" ".join(result), tuple(params))]
+
         if can_bulk:
-            result.append(self.connection.ops.bulk_insert_sql(fields, len(values)))
-            return [(" ".join(result), tuple(v for val in values for v in val))]
+            result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
+            return [(" ".join(result), tuple(p for ps in param_rows for p in ps))]
         else:
             return [
                 (" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals)
-                for p, vals in zip(placeholders, params)
+                for p, vals in zip(placeholder_rows, param_rows)
             ]
 
     def execute_sql(self, return_id=False):
@@ -1034,10 +1112,11 @@ class SQLUpdateCompiler(SQLCompiler):
                         connection=self.connection,
                     )
                 else:
-                    raise TypeError("Database is trying to update a relational field "
-                                    "of type %s with a value of type %s. Make sure "
-                                    "you are setting the correct relations" %
-                                    (field.__class__.__name__, val.__class__.__name__))
+                    raise TypeError(
+                        "Tried to update field %s with a model instance, %r. "
+                        "Use a value compatible with %s."
+                        % (field, val, field.__class__.__name__)
+                    )
             else:
                 val = field.get_db_prep_save(val, connection=self.connection)
 

+ 3 - 3
django/db/models/sql/subqueries.py

@@ -139,9 +139,9 @@ class UpdateQuery(Query):
 
     def add_update_fields(self, values_seq):
         """
-        Turn a sequence of (field, model, value) triples into an update query.
-        Used by add_update_values() as well as the "fast" update path when
-        saving models.
+        Append a sequence of (field, model, value) triples to the internal list
+        that will be used to generate the UPDATE query. Might be more usefully
+        called add_update_targets() to hint at the extra information here.
         """
         self.values.extend(values_seq)
 

+ 22 - 6
docs/ref/models/expressions.txt

@@ -5,10 +5,14 @@ Query Expressions
 .. currentmodule:: django.db.models
 
 Query expressions describe a value or a computation that can be used as part of
-a filter, order by, annotation, or aggregate. There are a number of built-in
-expressions (documented below) that can be used to help you write queries.
-Expressions can be combined, or in some cases nested, to form more complex
-computations.
+an update, create, filter, order by, annotation, or aggregate. There are a
+number of built-in expressions (documented below) that can be used to help you
+write queries. Expressions can be combined, or in some cases nested, to form
+more complex computations.
+
+.. versionchanged:: 1.9
+
+    Support for using expressions when creating new model instances was added.
 
 Supported arithmetic
 ====================
@@ -27,7 +31,7 @@ Some examples
 .. code-block:: python
 
     from django.db.models import F, Count
-    from django.db.models.functions import Length
+    from django.db.models.functions import Length, Upper, Value
 
     # Find companies that have more employees than chairs.
     Company.objects.filter(num_employees__gt=F('num_chairs'))
@@ -49,6 +53,13 @@ Some examples
     >>> company.chairs_needed
     70
 
+    # Create a new company using expressions.
+    >>> company = Company.objects.create(name='Google', ticker=Upper(Value('goog')))
+    # Be sure to refresh it if you need to access the field.
+    >>> company.refresh_from_db()
+    >>> company.ticker
+    'GOOG'
+
     # Annotate models with an aggregated value. Both forms
     # below are equivalent.
     Company.objects.annotate(num_products=Count('products'))
@@ -122,6 +133,8 @@ and describe the operation.
    will need to be reloaded::
 
        reporter = Reporters.objects.get(pk=reporter.pk)
+       # Or, more succinctly:
+       reporter.refresh_from_db()
 
 As well as being used in operations on single instances as above, ``F()`` can
 be used on ``QuerySets`` of object instances, with ``update()``. This reduces
@@ -356,7 +369,10 @@ boolean, or string within an expression, you can wrap that value within a
 
 You will rarely need to use ``Value()`` directly. When you write the expression
 ``F('field') + 1``, Django implicitly wraps the ``1`` in a ``Value()``,
-allowing simple values to be used in more complex expressions.
+allowing simple values to be used in more complex expressions. You will need to
+use ``Value()`` when you want to pass a string to an expression. Most
+expressions interpret a string argument as the name of a field, like
+``Lower('name')``.
 
 The ``value`` argument describes the value to be included in the expression,
 such as ``1``, ``True``, or ``None``. Django knows how to convert these Python

+ 4 - 0
docs/releases/1.9.txt

@@ -542,6 +542,10 @@ Models
 * Added a new model field check that makes sure
   :attr:`~django.db.models.Field.default` is a valid value.
 
+* :doc:`Query expressions </ref/models/expressions>` can now be used when
+  creating new model instances using ``save()``, ``create()``, and
+  ``bulk_create()``.
+
 Requests and Responses
 ^^^^^^^^^^^^^^^^^^^^^^
 

+ 11 - 0
tests/bulk_create/tests.py

@@ -3,6 +3,8 @@ from __future__ import unicode_literals
 from operator import attrgetter
 
 from django.db import connection
+from django.db.models import Value
+from django.db.models.functions import Lower
 from django.test import (
     TestCase, override_settings, skipIfDBFeature, skipUnlessDBFeature,
 )
@@ -183,3 +185,12 @@ class BulkCreateTests(TestCase):
         TwoFields.objects.all().delete()
         with self.assertNumQueries(1):
             TwoFields.objects.bulk_create(objs, len(objs))
+
+    @skipUnlessDBFeature('has_bulk_insert')
+    def test_bulk_insert_expressions(self):
+        Restaurant.objects.bulk_create([
+            Restaurant(name="Sam's Shake Shack"),
+            Restaurant(name=Lower(Value("Betty's Beetroot Bar")))
+        ])
+        bbb = Restaurant.objects.filter(name="betty's beetroot bar")
+        self.assertEqual(bbb.count(), 1)

+ 42 - 1
tests/expressions/tests.py

@@ -249,6 +249,32 @@ class BasicExpressionsTests(TestCase):
         test_gmbh = Company.objects.get(pk=test_gmbh.pk)
         self.assertEqual(test_gmbh.num_employees, 36)
 
+    def test_new_object_save(self):
+        # We should be able to use Funcs when inserting new data
+        test_co = Company(
+            name=Lower(Value("UPPER")), num_employees=32, num_chairs=1,
+            ceo=Employee.objects.create(firstname="Just", lastname="Doit", salary=30),
+        )
+        test_co.save()
+        test_co.refresh_from_db()
+        self.assertEqual(test_co.name, "upper")
+
+    def test_new_object_create(self):
+        test_co = Company.objects.create(
+            name=Lower(Value("UPPER")), num_employees=32, num_chairs=1,
+            ceo=Employee.objects.create(firstname="Just", lastname="Doit", salary=30),
+        )
+        test_co.refresh_from_db()
+        self.assertEqual(test_co.name, "upper")
+
+    def test_object_create_with_aggregate(self):
+        # Aggregates are not allowed when inserting new data
+        with self.assertRaisesMessage(FieldError, 'Aggregate functions are not allowed in this query'):
+            Company.objects.create(
+                name='Company', num_employees=Max(Value(1)), num_chairs=1,
+                ceo=Employee.objects.create(firstname="Just", lastname="Doit", salary=30),
+            )
+
     def test_object_update_fk(self):
         # F expressions cannot be used to update attributes which are foreign
         # keys, or attributes which involve joins.
@@ -272,7 +298,22 @@ class BasicExpressionsTests(TestCase):
             ceo=test_gmbh.ceo
         )
         acme.num_employees = F("num_employees") + 16
-        self.assertRaises(TypeError, acme.save)
+        msg = (
+            'Failed to insert expression "Col(expressions_company, '
+            'expressions.Company.num_employees) + Value(16)" on '
+            'expressions.Company.num_employees. F() expressions can only be '
+            'used to update, not to insert.'
+        )
+        self.assertRaisesMessage(ValueError, msg, acme.save)
+
+        acme.num_employees = 12
+        acme.name = Lower(F('name'))
+        msg = (
+            'Failed to insert expression "Lower(Col(expressions_company, '
+            'expressions.Company.name))" on expressions.Company.name. F() '
+            'expressions can only be used to update, not to insert.'
+        )
+        self.assertRaisesMessage(ValueError, msg, acme.save)
 
     def test_ticket_11722_iexact_lookup(self):
         Employee.objects.create(firstname="John", lastname="Doe")

+ 6 - 1
tests/model_fields/tests.py

@@ -98,8 +98,13 @@ class BasicFieldTests(test.TestCase):
         self.assertTrue(instance.id)
         # Set field to object on saved instance
         instance.size = instance
+        msg = (
+            "Tried to update field model_fields.FloatModel.size with a model "
+            "instance, <FloatModel: FloatModel object>. Use a value "
+            "compatible with FloatField."
+        )
         with transaction.atomic():
-            with self.assertRaises(TypeError):
+            with self.assertRaisesMessage(TypeError, msg):
                 instance.save()
         # Try setting field to object on retrieved object
         obj = FloatModel.objects.get(pk=instance.id)