Browse Source

Refs #28900 -- Made SELECT respect the order specified by values(*selected).

Previously the order was always extra_fields + model_fields + annotations with
respective local ordering inferred from the insertion order of *selected.

This commit introduces a new `Query.selected` property that keeps track of the
global select order as specified on values assignment. This is a crucial
feature to allow the combination of queries mixing annotations and table
references.

It also allows the removal of the re-ordering shenanigans performed by
ValuesListIterable in order to re-map the tuples returned from the database
backend to the order specified by values_list() as they'll be in the right
order at query compilation time.

Refs #28553 as the initially reported issue that was only partially fixed
for annotations by d6b6e5d0fd4e6b6d0183b4cf6e4bd4f9afc7bf67.

Thanks Mariusz Felisiak and Sarah Boyce for review.
Simon Charette 2 years ago
parent
commit
65ad4ade74

+ 9 - 28
django/db/models/query.py

@@ -200,12 +200,15 @@ class ValuesIterable(BaseIterable):
         query = queryset.query
         compiler = query.get_compiler(queryset.db)
 
-        # extra(select=...) cols are always at the start of the row.
-        names = [
-            *query.extra_select,
-            *query.values_select,
-            *query.annotation_select,
-        ]
+        if query.selected:
+            names = list(query.selected)
+        else:
+            # extra(select=...) cols are always at the start of the row.
+            names = [
+                *query.extra_select,
+                *query.values_select,
+                *query.annotation_select,
+            ]
         indexes = range(len(names))
         for row in compiler.results_iter(
             chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
@@ -223,28 +226,6 @@ class ValuesListIterable(BaseIterable):
         queryset = self.queryset
         query = queryset.query
         compiler = query.get_compiler(queryset.db)
-
-        if queryset._fields:
-            # extra(select=...) cols are always at the start of the row.
-            names = [
-                *query.extra_select,
-                *query.values_select,
-                *query.annotation_select,
-            ]
-            fields = [
-                *queryset._fields,
-                *(f for f in query.annotation_select if f not in queryset._fields),
-            ]
-            if fields != names:
-                # Reorder according to fields.
-                index_map = {name: idx for idx, name in enumerate(names)}
-                rowfactory = operator.itemgetter(*[index_map[f] for f in fields])
-                return map(
-                    rowfactory,
-                    compiler.results_iter(
-                        chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
-                    ),
-                )
         return compiler.results_iter(
             tuple_expected=True,
             chunked_fetch=self.chunked_fetch,

+ 30 - 15
django/db/models/sql/compiler.py

@@ -247,11 +247,6 @@ class SQLCompiler:
         select = []
         klass_info = None
         annotations = {}
-        select_idx = 0
-        for alias, (sql, params) in self.query.extra_select.items():
-            annotations[alias] = select_idx
-            select.append((RawSQL(sql, params), alias))
-            select_idx += 1
         assert not (self.query.select and self.query.default_cols)
         select_mask = self.query.get_select_mask()
         if self.query.default_cols:
@@ -261,19 +256,39 @@ class SQLCompiler:
             # any model.
             cols = self.query.select
         if cols:
-            select_list = []
-            for col in cols:
-                select_list.append(select_idx)
-                select.append((col, None))
-                select_idx += 1
             klass_info = {
                 "model": self.query.model,
-                "select_fields": select_list,
+                "select_fields": list(
+                    range(
+                        len(self.query.extra_select),
+                        len(self.query.extra_select) + len(cols),
+                    )
+                ),
             }
-        for alias, annotation in self.query.annotation_select.items():
-            annotations[alias] = select_idx
-            select.append((annotation, alias))
-            select_idx += 1
+        selected = []
+        if self.query.selected is None:
+            selected = [
+                *(
+                    (alias, RawSQL(*args))
+                    for alias, args in self.query.extra_select.items()
+                ),
+                *((None, col) for col in cols),
+                *self.query.annotation_select.items(),
+            ]
+        else:
+            for alias, expression in self.query.selected.items():
+                # Reference to an annotation.
+                if isinstance(expression, str):
+                    expression = self.query.annotations[expression]
+                # Reference to a column.
+                elif isinstance(expression, int):
+                    expression = cols[expression]
+                selected.append((alias, expression))
+
+        for select_idx, (alias, expression) in enumerate(selected):
+            if alias:
+                annotations[alias] = select_idx
+            select.append((expression, alias))
 
         if self.query.select_related:
             related_klass_infos = self.get_related_selections(select, select_mask)

+ 33 - 14
django/db/models/sql/query.py

@@ -26,6 +26,7 @@ from django.db.models.expressions import (
     Exists,
     F,
     OuterRef,
+    RawSQL,
     Ref,
     ResolvedOuterRef,
     Value,
@@ -265,6 +266,7 @@ class Query(BaseExpression):
     # Holds the selects defined by a call to values() or values_list()
     # excluding annotation_select and extra_select.
     values_select = ()
+    selected = None
 
     # SQL annotation-related attributes.
     annotation_select_mask = None
@@ -584,6 +586,7 @@ class Query(BaseExpression):
         else:
             outer_query = self
             self.select = ()
+            self.selected = None
             self.default_cols = False
             self.extra = {}
             if self.annotations:
@@ -1194,13 +1197,10 @@ class Query(BaseExpression):
         if select:
             self.append_annotation_mask([alias])
         else:
-            annotation_mask = (
-                value
-                for value in dict.fromkeys(self.annotation_select)
-                if value != alias
-            )
-            self.set_annotation_mask(annotation_mask)
+            self.set_annotation_mask(set(self.annotation_select).difference({alias}))
         self.annotations[alias] = annotation
+        if self.selected:
+            self.selected[alias] = alias
 
     def resolve_expression(self, query, *args, **kwargs):
         clone = self.clone()
@@ -2153,6 +2153,7 @@ class Query(BaseExpression):
         self.select_related = False
         self.set_extra_mask(())
         self.set_annotation_mask(())
+        self.selected = None
 
     def clear_select_fields(self):
         """
@@ -2162,10 +2163,12 @@ class Query(BaseExpression):
         """
         self.select = ()
         self.values_select = ()
+        self.selected = None
 
     def add_select_col(self, col, name):
         self.select += (col,)
         self.values_select += (name,)
+        self.selected[name] = len(self.select) - 1
 
     def set_select(self, cols):
         self.default_cols = False
@@ -2416,12 +2419,23 @@ class Query(BaseExpression):
         if names is None:
             self.annotation_select_mask = None
         else:
-            self.annotation_select_mask = list(dict.fromkeys(names))
+            self.annotation_select_mask = set(names)
+            if self.selected:
+                # Prune the masked annotations.
+                self.selected = {
+                    key: value
+                    for key, value in self.selected.items()
+                    if not isinstance(value, str)
+                    or value in self.annotation_select_mask
+                }
+                # Append the unmasked annotations.
+                for name in names:
+                    self.selected[name] = name
         self._annotation_select_cache = None
 
     def append_annotation_mask(self, names):
         if self.annotation_select_mask is not None:
-            self.set_annotation_mask((*self.annotation_select_mask, *names))
+            self.set_annotation_mask(self.annotation_select_mask.union(names))
 
     def set_extra_mask(self, names):
         """
@@ -2440,6 +2454,7 @@ class Query(BaseExpression):
         self.clear_select_fields()
         self.has_select_fields = True
 
+        selected = {}
         if fields:
             field_names = []
             extra_names = []
@@ -2448,13 +2463,16 @@ class Query(BaseExpression):
                 # Shortcut - if there are no extra or annotations, then
                 # the values() clause must be just field names.
                 field_names = list(fields)
+                selected = dict(zip(fields, range(len(fields))))
             else:
                 self.default_cols = False
                 for f in fields:
-                    if f in self.extra_select:
+                    if extra := self.extra_select.get(f):
                         extra_names.append(f)
+                        selected[f] = RawSQL(*extra)
                     elif f in self.annotation_select:
                         annotation_names.append(f)
+                        selected[f] = f
                     elif f in self.annotations:
                         raise FieldError(
                             f"Cannot select the '{f}' alias. Use annotate() to "
@@ -2466,13 +2484,13 @@ class Query(BaseExpression):
                         # `f` is not resolvable.
                         if self.annotation_select:
                             self.names_to_path(f.split(LOOKUP_SEP), self.model._meta)
+                        selected[f] = len(field_names)
                         field_names.append(f)
             self.set_extra_mask(extra_names)
             self.set_annotation_mask(annotation_names)
-            selected = frozenset(field_names + extra_names + annotation_names)
         else:
             field_names = [f.attname for f in self.model._meta.concrete_fields]
-            selected = frozenset(field_names)
+            selected = dict.fromkeys(field_names, None)
         # Selected annotations must be known before setting the GROUP BY
         # clause.
         if self.group_by is True:
@@ -2495,6 +2513,7 @@ class Query(BaseExpression):
 
         self.values_select = tuple(field_names)
         self.add_fields(field_names, True)
+        self.selected = selected if fields else None
 
     @property
     def annotation_select(self):
@@ -2508,9 +2527,9 @@ class Query(BaseExpression):
             return {}
         elif self.annotation_select_mask is not None:
             self._annotation_select_cache = {
-                k: self.annotations[k]
-                for k in self.annotation_select_mask
-                if k in self.annotations
+                k: v
+                for k, v in self.annotations.items()
+                if k in self.annotation_select_mask
             }
             return self._annotation_select_cache
         else:

+ 10 - 0
docs/ref/models/querysets.txt

@@ -745,6 +745,11 @@ You can also refer to fields on related models with reverse relations through
     ``"true"``, ``"false"``, and ``"null"`` strings for
     :class:`~django.db.models.JSONField` key transforms.
 
+.. versionchanged:: 5.2
+
+    The ``SELECT`` clause generated when using ``values()`` was updated to
+    respect the order of the specified ``*fields`` and ``**expressions``.
+
 ``values_list()``
 ~~~~~~~~~~~~~~~~~
 
@@ -835,6 +840,11 @@ not having any author:
     ``"true"``, ``"false"``, and ``"null"`` strings for
     :class:`~django.db.models.JSONField` key transforms.
 
+.. versionchanged:: 5.2
+
+    The ``SELECT`` clause generated when using ``values_list()`` was updated to
+    respect the order of the specified ``*fields``.
+
 ``dates()``
 ~~~~~~~~~~~
 

+ 7 - 1
docs/releases/5.2.txt

@@ -195,7 +195,13 @@ Migrations
 Models
 ~~~~~~
 
-* ...
+* The ``SELECT`` clause generated when using
+  :meth:`QuerySet.values()<django.db.models.query.QuerySet.values>` and
+  :meth:`~django.db.models.query.QuerySet.values_list` now matches the
+  specified order of the referenced expressions. Previously the order was based
+  on a set of counterintuitive rules which made query combination through
+  methods such as
+  :meth:`QuerySet.union()<django.db.models.query.QuerySet.union>` unpredictable.
 
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~

+ 1 - 0
docs/spelling_wordlist

@@ -96,6 +96,7 @@ contenttypes
 contrib
 coroutine
 coroutines
+counterintuitive
 criticals
 cron
 crontab

+ 10 - 0
tests/annotations/tests.py

@@ -568,6 +568,16 @@ class NonAggregateAnnotationTestCase(TestCase):
         self.assertEqual(book["other_rating"], 4)
         self.assertEqual(book["other_isbn"], "155860191")
 
+    def test_values_fields_annotations_order(self):
+        qs = Book.objects.annotate(other_rating=F("rating") - 1).values(
+            "other_rating", "rating"
+        )
+        book = qs.get(pk=self.b1.pk)
+        self.assertEqual(
+            list(book.items()),
+            [("other_rating", self.b1.rating - 1), ("rating", self.b1.rating)],
+        )
+
     def test_values_with_pk_annotation(self):
         # annotate references a field in values() with pk
         publishers = Publisher.objects.values("id", "book__rating").annotate(

+ 2 - 2
tests/postgres_tests/test_array.py

@@ -466,8 +466,8 @@ class TestQuerying(PostgreSQLTestCase):
                 ],
             )
         sql = ctx[0]["sql"]
-        self.assertIn("GROUP BY 2", sql)
-        self.assertIn("ORDER BY 2", sql)
+        self.assertIn("GROUP BY 1", sql)
+        self.assertIn("ORDER BY 1", sql)
 
     def test_order_by_arrayagg_index(self):
         qs = (

+ 22 - 1
tests/queries/test_qs_combinators.py

@@ -257,6 +257,23 @@ class QuerySetSetOperationTests(TestCase):
         )
         self.assertCountEqual(qs1.union(qs2), [(1, 0), (1, 2)])
 
+    def test_union_with_field_and_annotation_values(self):
+        qs1 = (
+            Number.objects.filter(num=1)
+            .annotate(
+                zero=Value(0, IntegerField()),
+            )
+            .values_list("num", "zero")
+        )
+        qs2 = (
+            Number.objects.filter(num=2)
+            .annotate(
+                zero=Value(0, IntegerField()),
+            )
+            .values_list("zero", "num")
+        )
+        self.assertCountEqual(qs1.union(qs2), [(1, 0), (0, 2)])
+
     def test_union_with_extra_and_values_list(self):
         qs1 = (
             Number.objects.filter(num=1)
@@ -265,7 +282,11 @@ class QuerySetSetOperationTests(TestCase):
             )
             .values_list("num", "count")
         )
-        qs2 = Number.objects.filter(num=2).extra(select={"count": 1})
+        qs2 = (
+            Number.objects.filter(num=2)
+            .extra(select={"count": 1})
+            .values_list("num", "count")
+        )
         self.assertCountEqual(qs1.union(qs2), [(1, 0), (2, 1)])
 
     def test_union_with_values_list_on_annotated_and_unannotated(self):

+ 1 - 1
tests/queries/tests.py

@@ -2200,7 +2200,7 @@ class Queries6Tests(TestCase):
                 {"tag_per_parent__max": 2},
             )
         sql = captured_queries[0]["sql"]
-        self.assertIn("AS %s" % connection.ops.quote_name("col1"), sql)
+        self.assertIn("AS %s" % connection.ops.quote_name("parent"), sql)
 
     def test_xor_subquery(self):
         self.assertSequenceEqual(