Browse Source

Fixed #28897 -- Fixed QuerySet.update() on querysets ordered by annotations.

David Wobrock 2 years ago
parent
commit
3ef37a5245
3 changed files with 38 additions and 8 deletions
  1. 1 0
      django/db/backends/mysql/features.py
  2. 14 0
      django/db/models/query.py
  3. 23 8
      tests/update/tests.py

+ 1 - 0
django/db/backends/mysql/features.py

@@ -113,6 +113,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
             "related fields.": {
                 "update.tests.AdvancedTests."
                 "test_update_ordered_by_inline_m2m_annotation",
+                "update.tests.AdvancedTests.test_update_ordered_by_m2m_annotation",
             },
         }
         if "ONLY_FULL_GROUP_BY" in self.connection.sql_mode:

+ 14 - 0
django/db/models/query.py

@@ -1169,6 +1169,20 @@ class QuerySet:
         self._for_write = True
         query = self.query.chain(sql.UpdateQuery)
         query.add_update_values(kwargs)
+
+        # Inline annotations in order_by(), if possible.
+        new_order_by = []
+        for col in query.order_by:
+            if annotation := query.annotations.get(col):
+                if getattr(annotation, "contains_aggregate", False):
+                    raise exceptions.FieldError(
+                        f"Cannot update when ordering by an aggregate: {annotation}"
+                    )
+                new_order_by.append(annotation)
+            else:
+                new_order_by.append(col)
+        query.order_by = tuple(new_order_by)
+
         # Clear any annotations so that they won't be present in subqueries.
         query.annotations = {}
         with transaction.mark_for_rollback_on_error(using=self.db):

+ 23 - 8
tests/update/tests.py

@@ -225,6 +225,16 @@ class AdvancedTests(TestCase):
                             new_name=annotation,
                         ).update(name=F("new_name"))
 
+    def test_update_ordered_by_m2m_aggregation_annotation(self):
+        msg = (
+            "Cannot update when ordering by an aggregate: "
+            "Count(Col(update_bar_m2m_foo, update.Bar_m2m_foo.foo))"
+        )
+        with self.assertRaisesMessage(FieldError, msg):
+            Bar.objects.annotate(m2m_count=Count("m2m_foo")).order_by(
+                "m2m_count"
+            ).update(x=2)
+
     def test_update_ordered_by_inline_m2m_annotation(self):
         foo = Foo.objects.create(target="test")
         Bar.objects.create(foo=foo)
@@ -232,6 +242,13 @@ class AdvancedTests(TestCase):
         Bar.objects.order_by(Abs("m2m_foo")).update(x=2)
         self.assertEqual(Bar.objects.get().x, 2)
 
+    def test_update_ordered_by_m2m_annotation(self):
+        foo = Foo.objects.create(target="test")
+        Bar.objects.create(foo=foo)
+
+        Bar.objects.annotate(abs_id=Abs("m2m_foo")).order_by("abs_id").update(x=3)
+        self.assertEqual(Bar.objects.get().x, 3)
+
 
 @unittest.skipUnless(
     connection.vendor == "mysql",
@@ -259,14 +276,12 @@ class MySQLUpdateOrderByTest(TestCase):
                 self.assertEqual(updated, 2)
 
     def test_order_by_update_on_unique_constraint_annotation(self):
-        # Ordering by annotations is omitted because they cannot be resolved in
-        # .update().
-        with self.assertRaises(IntegrityError):
-            UniqueNumber.objects.annotate(number_inverse=F("number").desc(),).order_by(
-                "number_inverse"
-            ).update(
-                number=F("number") + 1,
-            )
+        updated = (
+            UniqueNumber.objects.annotate(number_inverse=F("number").desc())
+            .order_by("number_inverse")
+            .update(number=F("number") + 1)
+        )
+        self.assertEqual(updated, 2)
 
     def test_order_by_update_on_parent_unique_constraint(self):
         # Ordering by inherited fields is omitted because joined fields cannot