Browse Source

Fixed #27272 -- Added an on_delete RESTRICT handler to allow cascading deletions while protecting direct ones.

Daniel Izquierdo 8 years ago
parent
commit
89abecc75d

+ 2 - 0
django/contrib/admin/utils.py

@@ -181,6 +181,8 @@ class NestedObjects(Collector):
             return super().collect(objs, source_attr=source_attr, **kwargs)
         except models.ProtectedError as e:
             self.protected.update(e.protected_objects)
+        except models.RestrictedError as e:
+            self.protected.update(e.restricted_objects)
 
     def related_objects(self, related_model, related_fields, objs):
         qs = super().related_objects(related_model, related_fields, objs)

+ 4 - 3
django/db/models/__init__.py

@@ -5,7 +5,8 @@ from django.db.models.aggregates import __all__ as aggregates_all
 from django.db.models.constraints import *  # NOQA
 from django.db.models.constraints import __all__ as constraints_all
 from django.db.models.deletion import (
-    CASCADE, DO_NOTHING, PROTECT, SET, SET_DEFAULT, SET_NULL, ProtectedError,
+    CASCADE, DO_NOTHING, PROTECT, RESTRICT, SET, SET_DEFAULT, SET_NULL,
+    ProtectedError, RestrictedError,
 )
 from django.db.models.enums import *  # NOQA
 from django.db.models.enums import __all__ as enums_all
@@ -37,8 +38,8 @@ from django.db.models.fields.related import (  # isort:skip
 __all__ = aggregates_all + constraints_all + enums_all + fields_all + indexes_all
 __all__ += [
     'ObjectDoesNotExist', 'signals',
-    'CASCADE', 'DO_NOTHING', 'PROTECT', 'SET', 'SET_DEFAULT', 'SET_NULL',
-    'ProtectedError',
+    'CASCADE', 'DO_NOTHING', 'PROTECT', 'RESTRICT', 'SET', 'SET_DEFAULT',
+    'SET_NULL', 'ProtectedError', 'RestrictedError',
     'Case', 'Exists', 'Expression', 'ExpressionList', 'ExpressionWrapper', 'F',
     'Func', 'OuterRef', 'RowRange', 'Subquery', 'Value', 'ValueRange', 'When',
     'Window', 'WindowFrame',

+ 69 - 5
django/db/models/deletion.py

@@ -14,9 +14,17 @@ class ProtectedError(IntegrityError):
         super().__init__(msg, protected_objects)
 
 
+class RestrictedError(IntegrityError):
+    def __init__(self, msg, restricted_objects):
+        self.restricted_objects = restricted_objects
+        super().__init__(msg, restricted_objects)
+
+
 def CASCADE(collector, field, sub_objs, using):
-    collector.collect(sub_objs, source=field.remote_field.model,
-                      source_attr=field.name, nullable=field.null)
+    collector.collect(
+        sub_objs, source=field.remote_field.model, source_attr=field.name,
+        nullable=field.null, fail_on_restricted=False,
+    )
     if field.null and not connections[using].features.can_defer_constraint_checks:
         collector.add_field_update(field, None, sub_objs)
 
@@ -31,6 +39,11 @@ def PROTECT(collector, field, sub_objs, using):
     )
 
 
+def RESTRICT(collector, field, sub_objs, using):
+    collector.add_restricted_objects(field, sub_objs)
+    collector.add_dependency(field.remote_field.model, field.model)
+
+
 def SET(value):
     if callable(value):
         def set_on_delete(collector, field, sub_objs, using):
@@ -70,6 +83,8 @@ class Collector:
         self.data = defaultdict(set)
         # {model: {(field, value): {instances}}}
         self.field_updates = defaultdict(partial(defaultdict, set))
+        # {model: {field: {instances}}}
+        self.restricted_objects = defaultdict(partial(defaultdict, set))
         # fast_deletes is a list of queryset-likes that can be deleted without
         # fetching the objects into memory.
         self.fast_deletes = []
@@ -121,6 +136,26 @@ class Collector:
         model = objs[0].__class__
         self.field_updates[model][field, value].update(objs)
 
+    def add_restricted_objects(self, field, objs):
+        if objs:
+            model = objs[0].__class__
+            self.restricted_objects[model][field].update(objs)
+
+    def clear_restricted_objects_from_set(self, model, objs):
+        if model in self.restricted_objects:
+            self.restricted_objects[model] = {
+                field: items - objs
+                for field, items in self.restricted_objects[model].items()
+            }
+
+    def clear_restricted_objects_from_queryset(self, model, qs):
+        if model in self.restricted_objects:
+            objs = set(qs.filter(pk__in=[
+                obj.pk
+                for objs in self.restricted_objects[model].values() for obj in objs
+            ]))
+            self.clear_restricted_objects_from_set(model, objs)
+
     def _has_signal_listeners(self, model):
         return (
             signals.pre_delete.has_listeners(model) or
@@ -177,7 +212,8 @@ class Collector:
             return [objs]
 
     def collect(self, objs, source=None, nullable=False, collect_related=True,
-                source_attr=None, reverse_dependency=False, keep_parents=False):
+                source_attr=None, reverse_dependency=False, keep_parents=False,
+                fail_on_restricted=True):
         """
         Add 'objs' to the collection of objects to be deleted as well as all
         parent instances.  'objs' must be a homogeneous iterable collection of
@@ -194,6 +230,12 @@ class Collector:
         direction of an FK rather than the reverse direction.)
 
         If 'keep_parents' is True, data of parent model's will be not deleted.
+
+        If 'fail_on_restricted' is False, error won't be raised even if it's
+        prohibited to delete such objects due to RESTRICT, that defers
+        restricted object checking in recursive calls where the top-level call
+        may need to collect more objects to determine whether restricted ones
+        can be deleted.
         """
         if self.can_fast_delete(objs):
             self.fast_deletes.append(objs)
@@ -215,7 +257,8 @@ class Collector:
                     self.collect(parent_objs, source=model,
                                  source_attr=ptr.remote_field.related_name,
                                  collect_related=False,
-                                 reverse_dependency=True)
+                                 reverse_dependency=True,
+                                 fail_on_restricted=False)
         if not collect_related:
             return
 
@@ -259,7 +302,28 @@ class Collector:
             if hasattr(field, 'bulk_related_objects'):
                 # It's something like generic foreign key.
                 sub_objs = field.bulk_related_objects(new_objs, self.using)
-                self.collect(sub_objs, source=model, nullable=True)
+                self.collect(sub_objs, source=model, nullable=True, fail_on_restricted=False)
+
+        if fail_on_restricted:
+            # Raise an error if collected restricted objects (RESTRICT) aren't
+            # candidates for deletion also collected via CASCADE.
+            for model, instances in self.data.items():
+                self.clear_restricted_objects_from_set(model, instances)
+            for qs in self.fast_deletes:
+                self.clear_restricted_objects_from_queryset(qs.model, qs)
+            for model, fields in self.restricted_objects.items():
+                for field, objs in fields.items():
+                    for obj in objs:
+                        raise RestrictedError(
+                            "Cannot delete some instances of model '%s' "
+                            "because they are referenced through a restricted "
+                            "foreign key: '%s.%s'." % (
+                                field.remote_field.model.__name__,
+                                obj.__class__.__name__,
+                                field.name,
+                            ),
+                            objs,
+                        )
 
     def related_objects(self, related_model, related_fields, objs):
         """

+ 6 - 0
docs/ref/exceptions.txt

@@ -255,6 +255,12 @@ Raised to prevent deletion of referenced objects when using
 :attr:`django.db.models.PROTECT`. :exc:`models.ProtectedError` is a subclass
 of :exc:`IntegrityError`.
 
+.. exception:: models.RestrictedError
+
+Raised to prevent deletion of referenced objects when using
+:attr:`django.db.models.RESTRICT`. :exc:`models.RestrictedError` is a subclass
+of :exc:`IntegrityError`.
+
 .. currentmodule:: django.http
 
 Http Exceptions

+ 40 - 0
docs/ref/models/fields.txt

@@ -1470,6 +1470,46 @@ The possible values for :attr:`~ForeignKey.on_delete` are found in
     :exc:`~django.db.models.ProtectedError`, a subclass of
     :exc:`django.db.IntegrityError`.
 
+* .. attribute:: RESTRICT
+
+    .. versionadded:: 3.1
+
+    Prevent deletion of the referenced object by raising
+    :exc:`~django.db.models.RestrictedError` (a subclass of
+    :exc:`django.db.IntegrityError`). Unlike :attr:`PROTECT`, deletion of the
+    referenced object is allowed if it also references a different object
+    that is being deleted in the same operation, but via a :attr:`CASCADE`
+    relationship.
+
+    Consider this set of models::
+
+        class Artist(models.Model):
+            name = models.CharField(max_length=10)
+
+        class Album(models.Model):
+            artist = models.ForeignKey(Artist, on_delete=models.CASCADE)
+
+        class Song(models.Model):
+            artist = models.ForeignKey(Artist, on_delete=models.CASCADE)
+            album = models.ForeignKey(Album, on_delete=models.RESTRICT)
+
+    ``Artist`` can be deleted even if that implies deleting an ``Album``
+    which is referenced by a ``Song``, because ``Song`` also references
+    ``Artist`` itself through a cascading relationship. For example::
+
+        >>> artist_one = Artist.objects.create(name='artist one')
+        >>> artist_two = Artist.objects.create(name='artist two')
+        >>> album_one = Album.objects.create(artist=artist_one)
+        >>> album_two = Album.objects.create(artist=artist_two)
+        >>> song_one = Song.objects.create(artist=artist_one, album=album_one)
+        >>> song_two = Song.objects.create(artist=artist_one, album=album_two)
+        >>> album_one.delete()
+        # Raises RestrictedError.
+        >>> artist_two.delete()
+        # Raises RestrictedError.
+        >>> artist_one.delete()
+        (4, {'Song': 2, 'Album': 1, 'Artist': 1})
+
 * .. attribute:: SET_NULL
 
     Set the :class:`ForeignKey` null; this is only possible if

+ 5 - 0
docs/releases/3.1.txt

@@ -199,6 +199,11 @@ Models
   values under a certain (database-dependent) limit. Values from ``0`` to
   ``9223372036854775807`` are safe in all databases supported by Django.
 
+* The new :class:`~django.db.models.RESTRICT` option for
+  :attr:`~django.db.models.ForeignKey.on_delete` argument of ``ForeignKey`` and
+  ``OneToOneField`` emulates the behavior of the SQL constraint ``ON DELETE
+  RESTRICT``.
+
 Pagination
 ~~~~~~~~~~
 

+ 3 - 2
tests/admin_views/admin.py

@@ -41,8 +41,8 @@ from .models import (
     ReferencedByGenRel, ReferencedByInline, ReferencedByParent,
     RelatedPrepopulated, RelatedWithUUIDPKModel, Report, Reservation,
     Restaurant, RowLevelChangePermissionModel, Section, ShortMessage, Simple,
-    Sketch, State, Story, StumpJoke, Subscriber, SuperVillain, Telegram, Thing,
-    Topping, UnchangeableObject, UndeletableObject, UnorderedObject,
+    Sketch, Song, State, Story, StumpJoke, Subscriber, SuperVillain, Telegram,
+    Thing, Topping, UnchangeableObject, UndeletableObject, UnorderedObject,
     UserMessenger, UserProxy, Villain, Vodcast, Whatsit, Widget, Worker,
     WorkHour,
 )
@@ -1069,6 +1069,7 @@ site.register(ReadOnlyPizza, ReadOnlyPizzaAdmin)
 site.register(ReadablePizza)
 site.register(Topping, ToppingAdmin)
 site.register(Album, AlbumAdmin)
+site.register(Song)
 site.register(Question, QuestionAdmin)
 site.register(Answer, AnswerAdmin, date_hierarchy='question__posted')
 site.register(Answer2, date_hierarchy='question__expires')

+ 8 - 0
tests/admin_views/models.py

@@ -604,6 +604,14 @@ class Album(models.Model):
     title = models.CharField(max_length=30)
 
 
+class Song(models.Model):
+    name = models.CharField(max_length=20)
+    album = models.ForeignKey(Album, on_delete=models.RESTRICT)
+
+    def __str__(self):
+        return self.name
+
+
 class Employee(Person):
     code = models.CharField(max_length=20)
 

+ 29 - 2
tests/admin_views/tests.py

@@ -38,7 +38,7 @@ from . import customadmin
 from .admin import CityAdmin, site, site2
 from .models import (
     Actor, AdminOrderedAdminMethod, AdminOrderedCallable, AdminOrderedField,
-    AdminOrderedModelMethod, Answer, Answer2, Article, BarAccount, Book,
+    AdminOrderedModelMethod, Album, Answer, Answer2, Article, BarAccount, Book,
     Bookmark, Category, Chapter, ChapterXtra1, ChapterXtra2, Character, Child,
     Choice, City, Collector, Color, ComplexSortedPerson, CoverLetter,
     CustomArticle, CyclicOne, CyclicTwo, DooHickey, Employee, EmptyModel,
@@ -50,7 +50,7 @@ from .models import (
     PrePopulatedPost, Promo, Question, ReadablePizza, ReadOnlyPizza,
     Recommendation, Recommender, RelatedPrepopulated, RelatedWithUUIDPKModel,
     Report, Restaurant, RowLevelChangePermissionModel, SecretHideout, Section,
-    ShortMessage, Simple, State, Story, SuperSecretHideout, SuperVillain,
+    ShortMessage, Simple, Song, State, Story, SuperSecretHideout, SuperVillain,
     Telegram, TitleTranslation, Topping, UnchangeableObject, UndeletableObject,
     UnorderedObject, UserProxy, Villain, Vodcast, Whatsit, Widget, Worker,
     WorkHour,
@@ -2603,6 +2603,33 @@ class AdminViewDeletedObjectsTest(TestCase):
         self.assertEqual(Question.objects.count(), 1)
         self.assertContains(response, "would require deleting the following protected related objects")
 
+    def test_restricted(self):
+        album = Album.objects.create(title='Amaryllis')
+        song = Song.objects.create(album=album, name='Unity')
+        response = self.client.get(reverse('admin:admin_views_album_delete', args=(album.pk,)))
+        self.assertContains(
+            response,
+            'would require deleting the following protected related objects',
+        )
+        self.assertContains(
+            response,
+            '<li>Song: <a href="%s">Unity</a></li>'
+            % reverse('admin:admin_views_song_change', args=(song.pk,))
+        )
+
+    def test_post_delete_restricted(self):
+        album = Album.objects.create(title='Amaryllis')
+        Song.objects.create(album=album, name='Unity')
+        response = self.client.post(
+            reverse('admin:admin_views_album_delete', args=(album.pk,)),
+            {'post': 'yes'},
+        )
+        self.assertEqual(Album.objects.count(), 1)
+        self.assertContains(
+            response,
+            'would require deleting the following protected related objects',
+        )
+
     def test_not_registered(self):
         should_contain = """<li>Secret hideout: underground bunker"""
         response = self.client.get(reverse('admin:admin_views_villain_delete', args=(self.v1.pk,)))

+ 55 - 2
tests/delete/models.py

@@ -1,8 +1,17 @@
+from django.contrib.contenttypes.fields import (
+    GenericForeignKey, GenericRelation,
+)
+from django.contrib.contenttypes.models import ContentType
 from django.db import models
 
 
+class P(models.Model):
+    pass
+
+
 class R(models.Model):
     is_default = models.BooleanField(default=False)
+    p = models.ForeignKey(P, models.CASCADE, null=True)
 
     def __str__(self):
         return "%s" % self.pk
@@ -46,10 +55,12 @@ class A(models.Model):
     )
     cascade = models.ForeignKey(R, models.CASCADE, related_name='cascade_set')
     cascade_nullable = models.ForeignKey(R, models.CASCADE, null=True, related_name='cascade_nullable_set')
-    protect = models.ForeignKey(R, models.PROTECT, null=True)
+    protect = models.ForeignKey(R, models.PROTECT, null=True, related_name='protect_set')
+    restrict = models.ForeignKey(R, models.RESTRICT, null=True, related_name='restrict_set')
     donothing = models.ForeignKey(R, models.DO_NOTHING, null=True, related_name='donothing_set')
     child = models.ForeignKey(RChild, models.CASCADE, related_name="child")
     child_setnull = models.ForeignKey(RChild, models.SET_NULL, null=True, related_name="child_setnull")
+    cascade_p = models.ForeignKey(P, models.CASCADE, related_name='cascade_p_set', null=True)
 
     # A OneToOneField is just a ForeignKey unique=True, so we don't duplicate
     # all the tests; just one smoke test to ensure on_delete works for it as
@@ -61,7 +72,7 @@ def create_a(name):
     a = A(name=name)
     for name in ('auto', 'auto_nullable', 'setvalue', 'setnull', 'setdefault',
                  'setdefault_none', 'cascade', 'cascade_nullable', 'protect',
-                 'donothing', 'o2o_setnull'):
+                 'restrict', 'donothing', 'o2o_setnull'):
         r = R.objects.create()
         setattr(a, name, r)
     a.child = RChild.objects.create()
@@ -147,3 +158,45 @@ class SecondReferrer(models.Model):
     other_referrer = models.ForeignKey(
         Referrer, models.CASCADE, to_field='unique_field', related_name='+'
     )
+
+
+class DeleteTop(models.Model):
+    b1 = GenericRelation('GenericB1')
+    b2 = GenericRelation('GenericB2')
+
+
+class B1(models.Model):
+    delete_top = models.ForeignKey(DeleteTop, models.CASCADE)
+
+
+class B2(models.Model):
+    delete_top = models.ForeignKey(DeleteTop, models.CASCADE)
+
+
+class DeleteBottom(models.Model):
+    b1 = models.ForeignKey(B1, models.RESTRICT)
+    b2 = models.ForeignKey(B2, models.CASCADE)
+
+
+class GenericB1(models.Model):
+    content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
+    object_id = models.PositiveIntegerField()
+    generic_delete_top = GenericForeignKey('content_type', 'object_id')
+
+
+class GenericB2(models.Model):
+    content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
+    object_id = models.PositiveIntegerField()
+    generic_delete_top = GenericForeignKey('content_type', 'object_id')
+    generic_delete_bottom = GenericRelation('GenericDeleteBottom')
+
+
+class GenericDeleteBottom(models.Model):
+    generic_b1 = models.ForeignKey(GenericB1, models.RESTRICT)
+    content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
+    object_id = models.PositiveIntegerField()
+    generic_b2 = GenericForeignKey()
+
+
+class GenericDeleteBottomParent(models.Model):
+    generic_delete_bottom = models.ForeignKey(GenericDeleteBottom, on_delete=models.CASCADE)

+ 85 - 3
tests/delete/tests.py

@@ -1,13 +1,14 @@
 from math import ceil
 
 from django.db import IntegrityError, connection, models
-from django.db.models.deletion import Collector
+from django.db.models.deletion import Collector, RestrictedError
 from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE
 from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
 
 from .models import (
-    MR, A, Avatar, Base, Child, HiddenUser, HiddenUserProfile, M, M2MFrom,
-    M2MTo, MRNull, Origin, Parent, R, RChild, RChildChild, Referrer, S, T,
+    B1, B2, MR, A, Avatar, Base, Child, DeleteBottom, DeleteTop, GenericB1,
+    GenericB2, GenericDeleteBottom, HiddenUser, HiddenUserProfile, M, M2MFrom,
+    M2MTo, MRNull, Origin, P, Parent, R, RChild, RChildChild, Referrer, S, T,
     User, create_a, get_default_r,
 )
 
@@ -146,6 +147,87 @@ class OnDeleteTests(TestCase):
         a = A.objects.get(pk=a.pk)
         self.assertIsNone(a.o2o_setnull)
 
+    def test_restrict(self):
+        a = create_a('restrict')
+        msg = (
+            "Cannot delete some instances of model 'R' because they are "
+            "referenced through a restricted foreign key: 'A.restrict'."
+        )
+        with self.assertRaisesMessage(RestrictedError, msg):
+            a.restrict.delete()
+
+    def test_restrict_path_cascade_indirect(self):
+        a = create_a('restrict')
+        a.restrict.p = P.objects.create()
+        a.restrict.save()
+        msg = (
+            "Cannot delete some instances of model 'R' because they are "
+            "referenced through a restricted foreign key: 'A.restrict'."
+        )
+        with self.assertRaisesMessage(RestrictedError, msg):
+            a.restrict.p.delete()
+        # Object referenced also with CASCADE relationship can be deleted.
+        a.cascade.p = a.restrict.p
+        a.cascade.save()
+        a.restrict.p.delete()
+        self.assertFalse(A.objects.filter(name='restrict').exists())
+        self.assertFalse(R.objects.filter(pk=a.restrict_id).exists())
+
+    def test_restrict_path_cascade_direct(self):
+        a = create_a('restrict')
+        a.restrict.p = P.objects.create()
+        a.restrict.save()
+        a.cascade_p = a.restrict.p
+        a.save()
+        a.restrict.p.delete()
+        self.assertFalse(A.objects.filter(name='restrict').exists())
+        self.assertFalse(R.objects.filter(pk=a.restrict_id).exists())
+
+    def test_restrict_path_cascade_indirect_diamond(self):
+        delete_top = DeleteTop.objects.create()
+        b1 = B1.objects.create(delete_top=delete_top)
+        b2 = B2.objects.create(delete_top=delete_top)
+        DeleteBottom.objects.create(b1=b1, b2=b2)
+        msg = (
+            "Cannot delete some instances of model 'B1' because they are "
+            "referenced through a restricted foreign key: 'DeleteBottom.b1'."
+        )
+        with self.assertRaisesMessage(RestrictedError, msg):
+            b1.delete()
+        self.assertTrue(DeleteTop.objects.exists())
+        self.assertTrue(B1.objects.exists())
+        self.assertTrue(B2.objects.exists())
+        self.assertTrue(DeleteBottom.objects.exists())
+        # Object referenced also with CASCADE relationship can be deleted.
+        delete_top.delete()
+        self.assertFalse(DeleteTop.objects.exists())
+        self.assertFalse(B1.objects.exists())
+        self.assertFalse(B2.objects.exists())
+        self.assertFalse(DeleteBottom.objects.exists())
+
+    def test_restrict_gfk_no_fast_delete(self):
+        delete_top = DeleteTop.objects.create()
+        generic_b1 = GenericB1.objects.create(generic_delete_top=delete_top)
+        generic_b2 = GenericB2.objects.create(generic_delete_top=delete_top)
+        GenericDeleteBottom.objects.create(generic_b1=generic_b1, generic_b2=generic_b2)
+        msg = (
+            "Cannot delete some instances of model 'GenericB1' because they "
+            "are referenced through a restricted foreign key: "
+            "'GenericDeleteBottom.generic_b1'."
+        )
+        with self.assertRaisesMessage(RestrictedError, msg):
+            generic_b1.delete()
+        self.assertTrue(DeleteTop.objects.exists())
+        self.assertTrue(GenericB1.objects.exists())
+        self.assertTrue(GenericB2.objects.exists())
+        self.assertTrue(GenericDeleteBottom.objects.exists())
+        # Object referenced also with CASCADE relationship can be deleted.
+        delete_top.delete()
+        self.assertFalse(DeleteTop.objects.exists())
+        self.assertFalse(GenericB1.objects.exists())
+        self.assertFalse(GenericB2.objects.exists())
+        self.assertFalse(GenericDeleteBottom.objects.exists())
+
 
 class DeletionTests(TestCase):