Browse Source

Fixed #21204 -- Tracked field deferrals by field instead of models.

This ensures field deferral works properly when a model is involved
more than once in the same query with a distinct deferral mask.
Simon Charette 2 years ago
parent
commit
b3db6c8dcb

+ 4 - 4
django/db/models/query_utils.py

@@ -259,7 +259,7 @@ class RegisterLookupMixin:
         cls._clear_cached_lookups()
 
 
-def select_related_descend(field, restricted, requested, load_fields, reverse=False):
+def select_related_descend(field, restricted, requested, select_mask, reverse=False):
     """
     Return True if this field should be used to descend deeper for
     select_related() purposes. Used by both the query construction code
@@ -271,7 +271,7 @@ def select_related_descend(field, restricted, requested, load_fields, reverse=Fa
      * restricted - a boolean field, indicating if the field list has been
        manually restricted using a requested clause)
      * requested - The select_related() dictionary.
-     * load_fields - the set of fields to be loaded on this model
+     * select_mask - the dictionary of selected fields.
      * reverse - boolean, True if we are checking a reverse select related
     """
     if not field.remote_field:
@@ -287,9 +287,9 @@ def select_related_descend(field, restricted, requested, load_fields, reverse=Fa
         return False
     if (
         restricted
-        and load_fields
+        and select_mask
         and field.name in requested
-        and field.attname not in load_fields
+        and field not in select_mask
     ):
         raise FieldError(
             f"Field {field.model._meta.object_name}.{field.name} cannot be both "

+ 28 - 24
django/db/models/sql/compiler.py

@@ -256,8 +256,9 @@ class SQLCompiler:
             select.append((RawSQL(sql, params), alias))
             select_idx += 1
         assert not (self.query.select and self.query.default_cols)
+        select_mask = self.query.get_select_mask()
         if self.query.default_cols:
-            cols = self.get_default_columns()
+            cols = self.get_default_columns(select_mask)
         else:
             # self.query.select is a special case. These columns never go to
             # any model.
@@ -278,7 +279,7 @@ class SQLCompiler:
             select_idx += 1
 
         if self.query.select_related:
-            related_klass_infos = self.get_related_selections(select)
+            related_klass_infos = self.get_related_selections(select, select_mask)
             klass_info["related_klass_infos"] = related_klass_infos
 
             def get_select_from_parent(klass_info):
@@ -870,7 +871,9 @@ class SQLCompiler:
             # Finally do cleanup - get rid of the joins we created above.
             self.query.reset_refcounts(refcounts_before)
 
-    def get_default_columns(self, start_alias=None, opts=None, from_parent=None):
+    def get_default_columns(
+        self, select_mask, start_alias=None, opts=None, from_parent=None
+    ):
         """
         Compute the default columns for selecting every field in the base
         model. Will sometimes be called to pull in related models (e.g. via
@@ -886,7 +889,6 @@ class SQLCompiler:
         if opts is None:
             if (opts := self.query.get_meta()) is None:
                 return result
-        only_load = self.deferred_to_columns()
         start_alias = start_alias or self.query.get_initial_alias()
         # The 'seen_models' is used to optimize checking the needed parent
         # alias for a given field. This also includes None -> start_alias to
@@ -912,7 +914,7 @@ class SQLCompiler:
                 # parent model data is already present in the SELECT clause,
                 # and we want to avoid reloading the same data again.
                 continue
-            if field.model in only_load and field.attname not in only_load[field.model]:
+            if select_mask and field not in select_mask:
                 continue
             alias = self.query.join_parent_model(opts, model, start_alias, seen_models)
             column = field.get_col(alias)
@@ -1063,6 +1065,7 @@ class SQLCompiler:
     def get_related_selections(
         self,
         select,
+        select_mask,
         opts=None,
         root_alias=None,
         cur_depth=1,
@@ -1095,7 +1098,6 @@ class SQLCompiler:
         if not opts:
             opts = self.query.get_meta()
             root_alias = self.query.get_initial_alias()
-        only_load = self.deferred_to_columns()
 
         # Setup for the case when only particular related fields should be
         # included in the related selection.
@@ -1109,7 +1111,6 @@ class SQLCompiler:
             klass_info["related_klass_infos"] = related_klass_infos
 
         for f in opts.fields:
-            field_model = f.model._meta.concrete_model
             fields_found.add(f.name)
 
             if restricted:
@@ -1129,10 +1130,9 @@ class SQLCompiler:
             else:
                 next = False
 
-            if not select_related_descend(
-                f, restricted, requested, only_load.get(field_model)
-            ):
+            if not select_related_descend(f, restricted, requested, select_mask):
                 continue
+            related_select_mask = select_mask.get(f) or {}
             klass_info = {
                 "model": f.remote_field.model,
                 "field": f,
@@ -1148,7 +1148,7 @@ class SQLCompiler:
             _, _, _, joins, _, _ = self.query.setup_joins([f.name], opts, root_alias)
             alias = joins[-1]
             columns = self.get_default_columns(
-                start_alias=alias, opts=f.remote_field.model._meta
+                related_select_mask, start_alias=alias, opts=f.remote_field.model._meta
             )
             for col in columns:
                 select_fields.append(len(select))
@@ -1156,6 +1156,7 @@ class SQLCompiler:
             klass_info["select_fields"] = select_fields
             next_klass_infos = self.get_related_selections(
                 select,
+                related_select_mask,
                 f.remote_field.model._meta,
                 alias,
                 cur_depth + 1,
@@ -1171,8 +1172,9 @@ class SQLCompiler:
                 if o.field.unique and not o.many_to_many
             ]
             for f, model in related_fields:
+                related_select_mask = select_mask.get(f) or {}
                 if not select_related_descend(
-                    f, restricted, requested, only_load.get(model), reverse=True
+                    f, restricted, requested, related_select_mask, reverse=True
                 ):
                     continue
 
@@ -1195,7 +1197,10 @@ class SQLCompiler:
                 related_klass_infos.append(klass_info)
                 select_fields = []
                 columns = self.get_default_columns(
-                    start_alias=alias, opts=model._meta, from_parent=opts.model
+                    related_select_mask,
+                    start_alias=alias,
+                    opts=model._meta,
+                    from_parent=opts.model,
                 )
                 for col in columns:
                     select_fields.append(len(select))
@@ -1203,7 +1208,13 @@ class SQLCompiler:
                 klass_info["select_fields"] = select_fields
                 next = requested.get(f.related_query_name(), {})
                 next_klass_infos = self.get_related_selections(
-                    select, model._meta, alias, cur_depth + 1, next, restricted
+                    select,
+                    related_select_mask,
+                    model._meta,
+                    alias,
+                    cur_depth + 1,
+                    next,
+                    restricted,
                 )
                 get_related_klass_infos(klass_info, next_klass_infos)
 
@@ -1239,7 +1250,9 @@ class SQLCompiler:
                     }
                     related_klass_infos.append(klass_info)
                     select_fields = []
+                    field_select_mask = select_mask.get((name, f)) or {}
                     columns = self.get_default_columns(
+                        field_select_mask,
                         start_alias=alias,
                         opts=model._meta,
                         from_parent=opts.model,
@@ -1251,6 +1264,7 @@ class SQLCompiler:
                     next_requested = requested.get(name, {})
                     next_klass_infos = self.get_related_selections(
                         select,
+                        field_select_mask,
                         opts=model._meta,
                         root_alias=alias,
                         cur_depth=cur_depth + 1,
@@ -1377,16 +1391,6 @@ class SQLCompiler:
             )
         return result
 
-    def deferred_to_columns(self):
-        """
-        Convert the self.deferred_loading data structure to mapping of table
-        names to sets of column names which are to be loaded. Return the
-        dictionary.
-        """
-        columns = {}
-        self.query.deferred_to_data(columns)
-        return columns
-
     def get_converters(self, expressions):
         converters = {}
         for i, expression in enumerate(expressions):

+ 63 - 90
django/db/models/sql/query.py

@@ -718,7 +718,61 @@ class Query(BaseExpression):
         self.order_by = rhs.order_by or self.order_by
         self.extra_order_by = rhs.extra_order_by or self.extra_order_by
 
-    def deferred_to_data(self, target):
+    def _get_defer_select_mask(self, opts, mask, select_mask=None):
+        if select_mask is None:
+            select_mask = {}
+        select_mask[opts.pk] = {}
+        # All concrete fields that are not part of the defer mask must be
+        # loaded. If a relational field is encountered it gets added to the
+        # mask for it to be considered if `select_related` and the cycle continues
+        # by recursively calling this function.
+        for field in opts.concrete_fields:
+            field_mask = mask.pop(field.name, None)
+            if field_mask is None:
+                select_mask.setdefault(field, {})
+            elif field_mask:
+                if not field.is_relation:
+                    raise FieldError(next(iter(field_mask)))
+                field_select_mask = select_mask.setdefault(field, {})
+                related_model = field.remote_field.model._meta.concrete_model
+                self._get_defer_select_mask(
+                    related_model._meta, field_mask, field_select_mask
+                )
+        # Remaining defer entries must be references to reverse relationships.
+        # The following code is expected to raise FieldError if it encounters
+        # a malformed defer entry.
+        for field_name, field_mask in mask.items():
+            if filtered_relation := self._filtered_relations.get(field_name):
+                relation = opts.get_field(filtered_relation.relation_name)
+                field_select_mask = select_mask.setdefault((field_name, relation), {})
+                field = relation.field
+            else:
+                field = opts.get_field(field_name).field
+                field_select_mask = select_mask.setdefault(field, {})
+            related_model = field.model._meta.concrete_model
+            self._get_defer_select_mask(
+                related_model._meta, field_mask, field_select_mask
+            )
+        return select_mask
+
+    def _get_only_select_mask(self, opts, mask, select_mask=None):
+        if select_mask is None:
+            select_mask = {}
+        select_mask[opts.pk] = {}
+        # Only include fields mentioned in the mask.
+        for field_name, field_mask in mask.items():
+            field = opts.get_field(field_name)
+            field_select_mask = select_mask.setdefault(field, {})
+            if field_mask:
+                if not field.is_relation:
+                    raise FieldError(next(iter(field_mask)))
+                related_model = field.remote_field.model._meta.concrete_model
+                self._get_only_select_mask(
+                    related_model._meta, field_mask, field_select_mask
+                )
+        return select_mask
+
+    def get_select_mask(self):
         """
         Convert the self.deferred_loading data structure to an alternate data
         structure, describing the field that *will* be loaded. This is used to
@@ -726,81 +780,19 @@ class Query(BaseExpression):
         QuerySet class to work out which fields are being initialized on each
         model. Models that have all their fields included aren't mentioned in
         the result, only those that have field restrictions in place.
-
-        The "target" parameter is the instance that is populated (in place).
         """
         field_names, defer = self.deferred_loading
         if not field_names:
-            return
-        orig_opts = self.get_meta()
-        seen = {}
-        must_include = {orig_opts.concrete_model: {orig_opts.pk}}
+            return {}
+        mask = {}
         for field_name in field_names:
-            parts = field_name.split(LOOKUP_SEP)
-            cur_model = self.model._meta.concrete_model
-            opts = orig_opts
-            for name in parts[:-1]:
-                old_model = cur_model
-                if name in self._filtered_relations:
-                    name = self._filtered_relations[name].relation_name
-                source = opts.get_field(name)
-                if is_reverse_o2o(source):
-                    cur_model = source.related_model
-                else:
-                    cur_model = source.remote_field.model
-                cur_model = cur_model._meta.concrete_model
-                opts = cur_model._meta
-                # Even if we're "just passing through" this model, we must add
-                # both the current model's pk and the related reference field
-                # (if it's not a reverse relation) to the things we select.
-                if not is_reverse_o2o(source):
-                    must_include[old_model].add(source)
-                add_to_dict(must_include, cur_model, opts.pk)
-            field = opts.get_field(parts[-1])
-            is_reverse_object = field.auto_created and not field.concrete
-            model = field.related_model if is_reverse_object else field.model
-            model = model._meta.concrete_model
-            if model == opts.model:
-                model = cur_model
-            if not is_reverse_o2o(field):
-                add_to_dict(seen, model, field)
-
+            part_mask = mask
+            for part in field_name.split(LOOKUP_SEP):
+                part_mask = part_mask.setdefault(part, {})
+        opts = self.get_meta()
         if defer:
-            # We need to load all fields for each model, except those that
-            # appear in "seen" (for all models that appear in "seen"). The only
-            # slight complexity here is handling fields that exist on parent
-            # models.
-            workset = {}
-            for model, values in seen.items():
-                for field in model._meta.local_fields:
-                    if field not in values:
-                        m = field.model._meta.concrete_model
-                        add_to_dict(workset, m, field)
-            for model, values in must_include.items():
-                # If we haven't included a model in workset, we don't add the
-                # corresponding must_include fields for that model, since an
-                # empty set means "include all fields". That's why there's no
-                # "else" branch here.
-                if model in workset:
-                    workset[model].update(values)
-            for model, fields in workset.items():
-                target[model] = {f.attname for f in fields}
-        else:
-            for model, values in must_include.items():
-                if model in seen:
-                    seen[model].update(values)
-                else:
-                    # As we've passed through this model, but not explicitly
-                    # included any fields, we have to make sure it's mentioned
-                    # so that only the "must include" fields are pulled in.
-                    seen[model] = values
-            # Now ensure that every model in the inheritance chain is mentioned
-            # in the parent list. Again, it must be mentioned to ensure that
-            # only "must include" fields are pulled in.
-            for model in orig_opts.get_parent_list():
-                seen.setdefault(model, set())
-            for model, fields in seen.items():
-                target[model] = {f.attname for f in fields}
+            return self._get_defer_select_mask(opts, mask)
+        return self._get_only_select_mask(opts, mask)
 
     def table_alias(self, table_name, create=False, filtered_relation=None):
         """
@@ -2583,25 +2575,6 @@ def get_order_dir(field, default="ASC"):
     return field, dirn[0]
 
 
-def add_to_dict(data, key, value):
-    """
-    Add "value" to the set of values for "key", whether or not "key" already
-    exists.
-    """
-    if key in data:
-        data[key].add(value)
-    else:
-        data[key] = {value}
-
-
-def is_reverse_o2o(field):
-    """
-    Check if the given field is reverse-o2o. The field is expected to be some
-    sort of relation field or related object.
-    """
-    return field.is_relation and field.one_to_one and not field.concrete
-
-
 class JoinPromoter:
     """
     A class to abstract away join promotion problems for complex filter

+ 4 - 0
tests/defer/tests.py

@@ -290,6 +290,8 @@ class InvalidDeferTests(SimpleTestCase):
         msg = "Primary has no field named 'missing'"
         with self.assertRaisesMessage(FieldDoesNotExist, msg):
             list(Primary.objects.defer("missing"))
+        with self.assertRaisesMessage(FieldError, "missing"):
+            list(Primary.objects.defer("value__missing"))
         msg = "Secondary has no field named 'missing'"
         with self.assertRaisesMessage(FieldDoesNotExist, msg):
             list(Primary.objects.defer("related__missing"))
@@ -298,6 +300,8 @@ class InvalidDeferTests(SimpleTestCase):
         msg = "Primary has no field named 'missing'"
         with self.assertRaisesMessage(FieldDoesNotExist, msg):
             list(Primary.objects.only("missing"))
+        with self.assertRaisesMessage(FieldError, "missing"):
+            list(Primary.objects.only("value__missing"))
         msg = "Secondary has no field named 'missing'"
         with self.assertRaisesMessage(FieldDoesNotExist, msg):
             list(Primary.objects.only("related__missing"))

+ 22 - 2
tests/defer_regress/tests.py

@@ -246,8 +246,6 @@ class DeferRegressionTest(TestCase):
         )
         self.assertEqual(len(qs), 1)
 
-
-class DeferAnnotateSelectRelatedTest(TestCase):
     def test_defer_annotate_select_related(self):
         location = Location.objects.create()
         Request.objects.create(location=location)
@@ -276,6 +274,28 @@ class DeferAnnotateSelectRelatedTest(TestCase):
             list,
         )
 
+    def test_common_model_different_mask(self):
+        child = Child.objects.create(name="Child", value=42)
+        second_child = Child.objects.create(name="Second", value=64)
+        Leaf.objects.create(child=child, second_child=second_child)
+        with self.assertNumQueries(1):
+            leaf = (
+                Leaf.objects.select_related("child", "second_child")
+                .defer("child__name", "second_child__value")
+                .get()
+            )
+            self.assertEqual(leaf.child, child)
+            self.assertEqual(leaf.second_child, second_child)
+        self.assertEqual(leaf.child.get_deferred_fields(), {"name"})
+        self.assertEqual(leaf.second_child.get_deferred_fields(), {"value"})
+        with self.assertNumQueries(0):
+            self.assertEqual(leaf.child.value, 42)
+            self.assertEqual(leaf.second_child.name, "Second")
+        with self.assertNumQueries(1):
+            self.assertEqual(leaf.child.name, "Child")
+        with self.assertNumQueries(1):
+            self.assertEqual(leaf.second_child.value, 64)
+
 
 class DeferDeletionSignalsTests(TestCase):
     senders = [Item, Proxy]

+ 0 - 6
tests/queries/tests.py

@@ -3594,12 +3594,6 @@ class WhereNodeTest(SimpleTestCase):
 
 
 class QuerySetExceptionTests(SimpleTestCase):
-    def test_iter_exceptions(self):
-        qs = ExtraInfo.objects.only("author")
-        msg = "'ManyToOneRel' object has no attribute 'attname'"
-        with self.assertRaisesMessage(AttributeError, msg):
-            list(qs)
-
     def test_invalid_order_by(self):
         msg = "Cannot resolve keyword '*' into field. Choices are: created, id, name"
         with self.assertRaisesMessage(FieldError, msg):