Browse Source

Fixed #24305 -- Allowed overriding fields on abstract models.

Fields inherited from abstract base classes may be overridden like
any other Python attribute. Inheriting from multiple models/classes
with the same attribute name will follow the MRO.
Aron Podrigal 9 years ago
parent
commit
85ef98dc6e

+ 1 - 0
AUTHORS

@@ -69,6 +69,7 @@ answer newbie questions, and generally made Django that much better:
     Aram Dulyan
     arien <regexbot@gmail.com>
     Armin Ronacher
+    Aron Podrigal <aronp@guaranteedplus.com>
     Artem Gnilov <boobsd@gmail.com>
     Arthur <avandorp@gmail.com>
     Arthur Koziel <http://arthurkoziel.com>

+ 57 - 28
django/db/models/base.py

@@ -212,26 +212,35 @@ class ModelBase(type):
                 if isinstance(field, OneToOneField):
                     related = resolve_relation(new_class, field.remote_field.model)
                     parent_links[make_model_tuple(related)] = field
+
+        # Track fields inherited from base models.
+        inherited_attributes = set()
         # Do the appropriate setup for any model parents.
-        for base in parents:
+        for base in new_class.mro():
             original_base = base
-            if not hasattr(base, '_meta'):
+            if base not in parents or not hasattr(base, '_meta'):
                 # Things without _meta aren't functional models, so they're
                 # uninteresting parents.
+                inherited_attributes |= set(base.__dict__.keys())
                 continue
 
             parent_fields = base._meta.local_fields + base._meta.local_many_to_many
-            # Check for clashes between locally declared fields and those
-            # on the base classes (we cannot handle shadowed fields at the
-            # moment).
-            for field in parent_fields:
-                if field.name in field_names:
-                    raise FieldError(
-                        'Local field %r in class %r clashes '
-                        'with field of similar name from '
-                        'base class %r' % (field.name, name, base.__name__)
-                    )
             if not base._meta.abstract:
+                # Check for clashes between locally declared fields and those
+                # on the base classes.
+                for field in parent_fields:
+                    if field.name in field_names:
+                        raise FieldError(
+                            'Local field %r in class %r clashes with field of '
+                            'the same name from base class %r.' % (
+                                field.name,
+                                name,
+                                base.__name__,
+                            )
+                        )
+                    else:
+                        inherited_attributes.add(field.name)
+
                 # Concrete classes...
                 base = base._meta.concrete_model
                 base_key = make_model_tuple(base)
@@ -246,6 +255,18 @@ class ModelBase(type):
                         auto_created=True,
                         parent_link=True,
                     )
+
+                    if attr_name in field_names:
+                        raise FieldError(
+                            "Auto-generated field '%s' in class %r for "
+                            "parent_link to base class %r clashes with "
+                            "declared field of the same name." % (
+                                attr_name,
+                                name,
+                                base.__name__,
+                            )
+                        )
+
                     # Only add the ptr field if it's not already present;
                     # e.g. migrations will already have it specified
                     if not hasattr(new_class, attr_name):
@@ -256,16 +277,19 @@ class ModelBase(type):
             else:
                 base_parents = base._meta.parents.copy()
 
-                # .. and abstract ones.
+                # Add fields from abstract base class if it wasn't overridden.
                 for field in parent_fields:
-                    new_field = copy.deepcopy(field)
-                    new_class.add_to_class(field.name, new_field)
-                    # Replace parent links defined on this base by the new
-                    # field as it will be appropriately resolved if required.
-                    if field.one_to_one:
-                        for parent, parent_link in base_parents.items():
-                            if field == parent_link:
-                                base_parents[parent] = new_field
+                    if (field.name not in field_names and
+                            field.name not in new_class.__dict__ and
+                            field.name not in inherited_attributes):
+                        new_field = copy.deepcopy(field)
+                        new_class.add_to_class(field.name, new_field)
+                        # Replace parent links defined on this base by the new
+                        # field. It will be appropriately resolved if required.
+                        if field.one_to_one:
+                            for parent, parent_link in base_parents.items():
+                                if field == parent_link:
+                                    base_parents[parent] = new_field
 
                 # Pass any non-abstract parent classes onto child.
                 new_class._meta.parents.update(base_parents)
@@ -281,13 +305,18 @@ class ModelBase(type):
             # Inherit private fields (like GenericForeignKey) from the parent
             # class
             for field in base._meta.private_fields:
-                if base._meta.abstract and field.name in field_names:
-                    raise FieldError(
-                        'Local field %r in class %r clashes '
-                        'with field of similar name from '
-                        'abstract base class %r' % (field.name, name, base.__name__)
-                    )
-                new_class.add_to_class(field.name, copy.deepcopy(field))
+                if field.name in field_names:
+                    if not base._meta.abstract:
+                        raise FieldError(
+                            'Local field %r in class %r clashes with field of '
+                            'the same name from base class %r.' % (
+                                field.name,
+                                name,
+                                base.__name__,
+                            )
+                        )
+                else:
+                    new_class.add_to_class(field.name, copy.deepcopy(field))
 
         if abstract:
             # Abstract base models can't be instantiated and don't appear in

+ 2 - 0
docs/releases/1.10.txt

@@ -400,6 +400,8 @@ Models
   app label and class interpolation using the ``'%(app_label)s'`` and
   ``'%(class)s'`` strings.
 
+* Allowed overriding model fields inherited from abstract base classes.
+
 * The :func:`~django.db.models.prefetch_related_objects` function is now a
   public API.
 

+ 29 - 5
docs/topics/db/models.txt

@@ -1376,11 +1376,35 @@ Field name "hiding" is not permitted
 -------------------------------------
 
 In normal Python class inheritance, it is permissible for a child class to
-override any attribute from the parent class. In Django, this is not permitted
-for attributes that are :class:`~django.db.models.Field` instances (at
-least, not at the moment). If a base class has a field called ``author``, you
-cannot create another model field called ``author`` in any class that inherits
-from that base class.
+override any attribute from the parent class. In Django, this isn't usually
+permitted for model fields. If a non-abstract model base class has a field
+called ``author``, you can't create another model field or define
+an attribute called ``author`` in any class that inherits from that base class.
+
+This restriction doesn't apply to model fields inherited from an abstract
+model. Such fields may be overridden with another field or value, or be removed
+by setting ``field_name = None``.
+
+.. versionchanged:: 1.10
+
+    The ability to override abstract fields was added.
+
+.. warning::
+
+    Model managers are inherited from abstract base classes. Overriding an
+    inherited field which is referenced by an inherited
+    :class:`~django.db.models.Manager` may cause subtle bugs. See :ref:`custom
+    managers and model inheritance <custom-managers-and-inheritance>`.
+
+.. note::
+
+    Some fields define extra attributes on the model, e.g. a
+    :class:`~django.db.models.ForeignKey` defines an extra attribute with
+    ``_id`` appended to the field name, as well as ``related_name`` and
+    ``related_query_name`` on the foreign model.
+
+    These extra attributes cannot be overridden unless the field that defines
+    it is changed or removed so that it no longer defines the extra attribute.
 
 Overriding fields in a parent model leads to difficulties in areas such as
 initializing new instances (specifying which field is being initialized in

+ 354 - 0
tests/model_inheritance/test_abstract_inheritance.py

@@ -0,0 +1,354 @@
+from __future__ import unicode_literals
+
+from django.contrib.contenttypes.fields import (
+    GenericForeignKey, GenericRelation,
+)
+from django.contrib.contenttypes.models import ContentType
+from django.core.checks import Error
+from django.core.exceptions import FieldDoesNotExist, FieldError
+from django.db import models
+from django.test import TestCase
+from django.test.utils import isolate_apps
+
+
+@isolate_apps('model_inheritance')
+class AbstractInheritanceTests(TestCase):
+    def test_single_parent(self):
+        class AbstractBase(models.Model):
+            name = models.CharField(max_length=30)
+
+            class Meta:
+                abstract = True
+
+        class AbstractDescendant(AbstractBase):
+            name = models.CharField(max_length=50)
+
+            class Meta:
+                abstract = True
+
+        class DerivedChild(AbstractBase):
+            name = models.CharField(max_length=50)
+
+        class DerivedGrandChild(AbstractDescendant):
+            pass
+
+        self.assertEqual(AbstractDescendant._meta.get_field('name').max_length, 50)
+        self.assertEqual(DerivedChild._meta.get_field('name').max_length, 50)
+        self.assertEqual(DerivedGrandChild._meta.get_field('name').max_length, 50)
+
+    def test_multiple_parents_mro(self):
+        class AbstractBaseOne(models.Model):
+            class Meta:
+                abstract = True
+
+        class AbstractBaseTwo(models.Model):
+            name = models.CharField(max_length=30)
+
+            class Meta:
+                abstract = True
+
+        class DescendantOne(AbstractBaseOne, AbstractBaseTwo):
+            class Meta:
+                abstract = True
+
+        class DescendantTwo(AbstractBaseOne, AbstractBaseTwo):
+            name = models.CharField(max_length=50)
+
+            class Meta:
+                abstract = True
+
+        class Derived(DescendantOne, DescendantTwo):
+            pass
+
+        self.assertEqual(DescendantOne._meta.get_field('name').max_length, 30)
+        self.assertEqual(DescendantTwo._meta.get_field('name').max_length, 50)
+        self.assertEqual(Derived._meta.get_field('name').max_length, 50)
+
+    def test_multiple_inheritance_cannot_shadow_concrete_inherited_field(self):
+        class ConcreteParent(models.Model):
+            name = models.CharField(max_length=255)
+
+        class AbstractParent(models.Model):
+            name = models.IntegerField()
+
+            class Meta:
+                abstract = True
+
+        class FirstChild(ConcreteParent, AbstractParent):
+            pass
+
+        class AnotherChild(AbstractParent, ConcreteParent):
+            pass
+
+        self.assertIsInstance(FirstChild._meta.get_field('name'), models.CharField)
+        self.assertEqual(
+            AnotherChild.check(),
+            [Error(
+                "The field 'name' clashes with the field 'name' "
+                "from model 'model_inheritance.concreteparent'.",
+                obj=AnotherChild._meta.get_field('name'),
+                id="models.E006",
+            )]
+        )
+
+    def test_virtual_field(self):
+        class RelationModel(models.Model):
+            content_type = models.ForeignKey(ContentType, models.CASCADE)
+            object_id = models.PositiveIntegerField()
+            content_object = GenericForeignKey('content_type', 'object_id')
+
+        class RelatedModelAbstract(models.Model):
+            field = GenericRelation(RelationModel)
+
+            class Meta:
+                abstract = True
+
+        class ModelAbstract(models.Model):
+            field = models.CharField(max_length=100)
+
+            class Meta:
+                abstract = True
+
+        class OverrideRelatedModelAbstract(RelatedModelAbstract):
+            field = models.CharField(max_length=100)
+
+        class ExtendModelAbstract(ModelAbstract):
+            field = GenericRelation(RelationModel)
+
+        self.assertIsInstance(OverrideRelatedModelAbstract._meta.get_field('field'), models.CharField)
+        self.assertIsInstance(ExtendModelAbstract._meta.get_field('field'), GenericRelation)
+
+    def test_cannot_override_indirect_abstract_field(self):
+        class AbstractBase(models.Model):
+            name = models.CharField(max_length=30)
+
+            class Meta:
+                abstract = True
+
+        class ConcreteDescendant(AbstractBase):
+            pass
+
+        msg = (
+            "Local field 'name' in class 'Descendant' clashes with field of "
+            "the same name from base class 'ConcreteDescendant'."
+        )
+        with self.assertRaisesMessage(FieldError, msg):
+            class Descendant(ConcreteDescendant):
+                name = models.IntegerField()
+
+    def test_override_field_with_attr(self):
+        class AbstractBase(models.Model):
+            first_name = models.CharField(max_length=50)
+            last_name = models.CharField(max_length=50)
+            middle_name = models.CharField(max_length=30)
+            full_name = models.CharField(max_length=150)
+
+            class Meta:
+                abstract = True
+
+        class Descendant(AbstractBase):
+            middle_name = None
+
+            def full_name(self):
+                return self.first_name + self.last_name
+
+        with self.assertRaises(FieldDoesNotExist):
+            Descendant._meta.get_field('middle_name')
+
+        with self.assertRaises(FieldDoesNotExist):
+            Descendant._meta.get_field('full_name')
+
+    def test_overriding_field_removed_by_concrete_model(self):
+        class AbstractModel(models.Model):
+            foo = models.CharField(max_length=30)
+
+            class Meta:
+                abstract = True
+
+        class RemovedAbstractModelField(AbstractModel):
+            foo = None
+
+        class OverrideRemovedFieldByConcreteModel(RemovedAbstractModelField):
+            foo = models.CharField(max_length=50)
+
+        self.assertEqual(OverrideRemovedFieldByConcreteModel._meta.get_field('foo').max_length, 50)
+
+    def test_shadowed_fkey_id(self):
+        class Foo(models.Model):
+            pass
+
+        class AbstractBase(models.Model):
+            foo = models.ForeignKey(Foo, models.CASCADE)
+
+            class Meta:
+                abstract = True
+
+        class Descendant(AbstractBase):
+            foo_id = models.IntegerField()
+
+        self.assertEqual(
+            Descendant.check(),
+            [Error(
+                "The field 'foo_id' clashes with the field 'foo' "
+                "from model 'model_inheritance.descendant'.",
+                obj=Descendant._meta.get_field('foo_id'),
+                id='models.E006',
+            )]
+        )
+
+    def test_shadow_related_name_when_set_to_none(self):
+        class AbstractBase(models.Model):
+            bar = models.IntegerField()
+
+            class Meta:
+                abstract = True
+
+        class Foo(AbstractBase):
+            bar = None
+            foo = models.IntegerField()
+
+        class Bar(models.Model):
+            bar = models.ForeignKey(Foo, models.CASCADE, related_name='bar')
+
+        self.assertEqual(Bar.check(), [])
+
+    def test_reverse_foreign_key(self):
+        class AbstractBase(models.Model):
+            foo = models.CharField(max_length=100)
+
+            class Meta:
+                abstract = True
+
+        class Descendant(AbstractBase):
+            pass
+
+        class Foo(models.Model):
+            foo = models.ForeignKey(Descendant, models.CASCADE, related_name='foo')
+
+        self.assertEqual(
+            Foo._meta.get_field('foo').check(),
+            [
+                Error(
+                    "Reverse accessor for 'Foo.foo' clashes with field name 'Descendant.foo'.",
+                    hint=(
+                        "Rename field 'Descendant.foo', or add/change a related_name "
+                        "argument to the definition for field 'Foo.foo'."
+                    ),
+                    obj=Foo._meta.get_field('foo'),
+                    id='fields.E302',
+                ),
+                Error(
+                    "Reverse query name for 'Foo.foo' clashes with field name 'Descendant.foo'.",
+                    hint=(
+                        "Rename field 'Descendant.foo', or add/change a related_name "
+                        "argument to the definition for field 'Foo.foo'."
+                    ),
+                    obj=Foo._meta.get_field('foo'),
+                    id='fields.E303',
+                ),
+            ]
+        )
+
+    def test_multi_inheritance_field_clashes(self):
+        class AbstractBase(models.Model):
+            name = models.CharField(max_length=30)
+
+            class Meta:
+                abstract = True
+
+        class ConcreteBase(AbstractBase):
+            pass
+
+        class AbstractDescendant(ConcreteBase):
+            class Meta:
+                abstract = True
+
+        class ConcreteDescendant(AbstractDescendant):
+            name = models.CharField(max_length=100)
+
+        self.assertEqual(
+            ConcreteDescendant.check(),
+            [Error(
+                "The field 'name' clashes with the field 'name' from "
+                "model 'model_inheritance.concretebase'.",
+                obj=ConcreteDescendant._meta.get_field('name'),
+                id="models.E006",
+            )]
+        )
+
+    def test_override_one2one_relation_auto_field_clashes(self):
+        class ConcreteParent(models.Model):
+            name = models.CharField(max_length=255)
+
+        class AbstractParent(models.Model):
+            name = models.IntegerField()
+
+            class Meta:
+                abstract = True
+
+        msg = (
+            "Auto-generated field 'concreteparent_ptr' in class 'Descendant' "
+            "for parent_link to base class 'ConcreteParent' clashes with "
+            "declared field of the same name."
+        )
+        with self.assertRaisesMessage(FieldError, msg):
+            class Descendant(ConcreteParent, AbstractParent):
+                concreteparent_ptr = models.CharField(max_length=30)
+
+    def test_abstract_model_with_regular_python_mixin_mro(self):
+        class AbstractModel(models.Model):
+            name = models.CharField(max_length=255)
+            age = models.IntegerField()
+
+            class Meta:
+                abstract = True
+
+        class Mixin(object):
+            age = None
+
+        class Mixin2(object):
+            age = 2
+
+        class DescendantMixin(Mixin):
+            pass
+
+        class ConcreteModel(models.Model):
+            foo = models.IntegerField()
+
+        class ConcreteModel2(ConcreteModel):
+            age = models.SmallIntegerField()
+
+        def fields(model):
+            if not hasattr(model, '_meta'):
+                return list()
+            return list((f.name, f.__class__) for f in model._meta.get_fields())
+
+        model_dict = {'__module__': 'model_inheritance'}
+        model1 = type(str('Model1'), (AbstractModel, Mixin), model_dict.copy())
+        model2 = type(str('Model2'), (Mixin2, AbstractModel), model_dict.copy())
+        model3 = type(str('Model3'), (DescendantMixin, AbstractModel), model_dict.copy())
+        model4 = type(str('Model4'), (Mixin2, Mixin, AbstractModel), model_dict.copy())
+        model5 = type(str('Model5'), (Mixin2, ConcreteModel2, Mixin, AbstractModel), model_dict.copy())
+
+        self.assertEqual(
+            fields(model1),
+            [('id', models.AutoField), ('name', models.CharField), ('age', models.IntegerField)]
+        )
+
+        self.assertEqual(fields(model2), [('id', models.AutoField), ('name', models.CharField)])
+        self.assertEqual(getattr(model2, 'age'), 2)
+
+        self.assertEqual(fields(model3), [('id', models.AutoField), ('name', models.CharField)])
+
+        self.assertEqual(fields(model4), [('id', models.AutoField), ('name', models.CharField)])
+        self.assertEqual(getattr(model4, 'age'), 2)
+
+        self.assertEqual(
+            fields(model5),
+            [
+                ('id', models.AutoField), ('foo', models.IntegerField),
+                ('concretemodel_ptr', models.OneToOneField),
+                ('age', models.SmallIntegerField), ('concretemodel2_ptr', models.OneToOneField),
+                ('name', models.CharField),
+            ]
+        )