Browse Source

Fixed #27654 -- Propagated alters_data attribute to callables overridden in subclasses.

Thanks Shai Berger and Adam Johnson for reviews and the implementation
idea.
LightDiscord 2 years ago
parent
commit
e20c9eb60a

+ 2 - 1
django/contrib/contenttypes/fields.py

@@ -16,6 +16,7 @@ from django.db.models.fields.related import (
 from django.db.models.query_utils import PathInfo
 from django.db.models.sql import AND
 from django.db.models.sql.where import WhereNode
+from django.db.models.utils import AltersData
 from django.utils.functional import cached_property
 
 
@@ -560,7 +561,7 @@ def create_generic_related_manager(superclass, rel):
     specific to generic relations.
     """
 
-    class GenericRelatedObjectManager(superclass):
+    class GenericRelatedObjectManager(superclass, AltersData):
         def __init__(self, instance=None):
             super().__init__()
 

+ 2 - 2
django/db/models/base.py

@@ -48,7 +48,7 @@ from django.db.models.signals import (
     pre_init,
     pre_save,
 )
-from django.db.models.utils import make_model_tuple
+from django.db.models.utils import AltersData, make_model_tuple
 from django.utils.encoding import force_str
 from django.utils.hashable import make_hashable
 from django.utils.text import capfirst, get_text_list
@@ -456,7 +456,7 @@ class ModelState:
     fields_cache = ModelStateFieldsCacheDescriptor()
 
 
-class Model(metaclass=ModelBase):
+class Model(AltersData, metaclass=ModelBase):
     def __init__(self, *args, **kwargs):
         # Alias some things as locals to avoid repeat global lookups
         cls = self.__class__

+ 2 - 1
django/db/models/fields/files.py

@@ -10,10 +10,11 @@ from django.core.files.utils import validate_file_name
 from django.db.models import signals
 from django.db.models.fields import Field
 from django.db.models.query_utils import DeferredAttribute
+from django.db.models.utils import AltersData
 from django.utils.translation import gettext_lazy as _
 
 
-class FieldFile(File):
+class FieldFile(File, AltersData):
     def __init__(self, instance, field, name):
         super().__init__(None, name)
         self.instance = instance

+ 3 - 3
django/db/models/fields/related_descriptors.py

@@ -76,7 +76,7 @@ from django.db.models.functions import RowNumber
 from django.db.models.lookups import GreaterThan, LessThanOrEqual
 from django.db.models.query import QuerySet
 from django.db.models.query_utils import DeferredAttribute
-from django.db.models.utils import resolve_callables
+from django.db.models.utils import AltersData, resolve_callables
 from django.utils.functional import cached_property
 
 
@@ -635,7 +635,7 @@ def create_reverse_many_to_one_manager(superclass, rel):
     the related model, and adds behaviors specific to many-to-one relations.
     """
 
-    class RelatedManager(superclass):
+    class RelatedManager(superclass, AltersData):
         def __init__(self, instance):
             super().__init__()
 
@@ -946,7 +946,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
     the related model, and adds behaviors specific to many-to-many relations.
     """
 
-    class ManyRelatedManager(superclass):
+    class ManyRelatedManager(superclass, AltersData):
         def __init__(self, instance=None):
             super().__init__()
 

+ 6 - 2
django/db/models/query.py

@@ -27,7 +27,11 @@ from django.db.models.expressions import Case, F, Ref, Value, When
 from django.db.models.functions import Cast, Trunc
 from django.db.models.query_utils import FilteredRelation, Q
 from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
-from django.db.models.utils import create_namedtuple_class, resolve_callables
+from django.db.models.utils import (
+    AltersData,
+    create_namedtuple_class,
+    resolve_callables,
+)
 from django.utils import timezone
 from django.utils.deprecation import RemovedInDjango50Warning
 from django.utils.functional import cached_property, partition
@@ -284,7 +288,7 @@ class FlatValuesListIterable(BaseIterable):
             yield row[0]
 
 
-class QuerySet:
+class QuerySet(AltersData):
     """Represent a lazy database lookup for a set of objects."""
 
     def __init__(self, model=None, query=None, using=None, hints=None):

+ 17 - 0
django/db/models/utils.py

@@ -50,3 +50,20 @@ def create_namedtuple_class(*names):
         (namedtuple("Row", names),),
         {"__reduce__": __reduce__, "__slots__": ()},
     )
+
+
+class AltersData:
+    """
+    Make subclasses preserve the alters_data attribute on overridden methods.
+    """
+
+    def __init_subclass__(cls, **kwargs):
+        for fn_name, fn in vars(cls).items():
+            if callable(fn) and not hasattr(fn, "alters_data"):
+                for base in cls.__bases__:
+                    if base_fn := getattr(base, fn_name, None):
+                        if hasattr(base_fn, "alters_data"):
+                            fn.alters_data = base_fn.alters_data
+                        break
+
+        super().__init_subclass__(**kwargs)

+ 3 - 2
django/forms/models.py

@@ -10,6 +10,7 @@ from django.core.exceptions import (
     ImproperlyConfigured,
     ValidationError,
 )
+from django.db.models.utils import AltersData
 from django.forms.fields import ChoiceField, Field
 from django.forms.forms import BaseForm, DeclarativeFieldsMetaclass
 from django.forms.formsets import BaseFormSet, formset_factory
@@ -329,7 +330,7 @@ class ModelFormMetaclass(DeclarativeFieldsMetaclass):
         return new_class
 
 
-class BaseModelForm(BaseForm):
+class BaseModelForm(BaseForm, AltersData):
     def __init__(
         self,
         data=None,
@@ -644,7 +645,7 @@ def modelform_factory(
 # ModelFormSets ##############################################################
 
 
-class BaseModelFormSet(BaseFormSet):
+class BaseModelFormSet(BaseFormSet, AltersData):
     """
     A ``FormSet`` for editing a queryset and/or adding new objects to it.
     """

+ 63 - 0
tests/template_tests/test_callables.py

@@ -1,5 +1,6 @@
 from unittest import TestCase
 
+from django.db.models.utils import AltersData
 from django.template import Context, Engine
 
 
@@ -63,6 +64,68 @@ class CallableVariablesTests(TestCase):
         # template rendering.
         self.assertEqual(my_doodad.num_calls, 0)
 
+    def test_alters_data_propagation(self):
+        class GrandParentLeft(AltersData):
+            def my_method(self):
+                return 42
+
+            my_method.alters_data = True
+
+        class ParentLeft(GrandParentLeft):
+            def change_alters_data_method(self):
+                return 63
+
+            change_alters_data_method.alters_data = True
+
+            def sub_non_callable_method(self):
+                return 64
+
+            sub_non_callable_method.alters_data = True
+
+        class ParentRight(AltersData):
+            def other_method(self):
+                return 52
+
+            other_method.alters_data = True
+
+        class Child(ParentLeft, ParentRight):
+            def my_method(self):
+                return 101
+
+            def other_method(self):
+                return 102
+
+            def change_alters_data_method(self):
+                return 103
+
+            change_alters_data_method.alters_data = False
+
+            sub_non_callable_method = 104
+
+        class GrandChild(Child):
+            pass
+
+        child = Child()
+        self.assertIs(child.my_method.alters_data, True)
+        self.assertIs(child.other_method.alters_data, True)
+        self.assertIs(child.change_alters_data_method.alters_data, False)
+
+        grand_child = GrandChild()
+        self.assertIs(grand_child.my_method.alters_data, True)
+        self.assertIs(grand_child.other_method.alters_data, True)
+        self.assertIs(grand_child.change_alters_data_method.alters_data, False)
+
+        c = Context({"element": grand_child})
+
+        t = self.engine.from_string("{{ element.my_method }}")
+        self.assertEqual(t.render(c), "")
+        t = self.engine.from_string("{{ element.other_method }}")
+        self.assertEqual(t.render(c), "")
+        t = self.engine.from_string("{{ element.change_alters_data_method }}")
+        self.assertEqual(t.render(c), "103")
+        t = self.engine.from_string("{{ element.sub_non_callable_method }}")
+        self.assertEqual(t.render(c), "104")
+
     def test_do_not_call(self):
         class Doodad:
             do_not_call_in_templates = True