Browse Source

Made the caching of related and reverse related objects consistent in OneToOneFields. Fixed #13839. Refs #17439.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@17890 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Aymeric Augustin 13 years ago
parent
commit
b90d4e5b74

+ 32 - 23
django/db/models/fields/related.py

@@ -249,11 +249,19 @@ class SingleRelatedObjectDescriptor(object):
         if instance is None:
             return self
         try:
-            return getattr(instance, self.cache_name)
+            rel_obj = getattr(instance, self.cache_name)
         except AttributeError:
             params = {'%s__pk' % self.related.field.name: instance._get_pk_val()}
-            rel_obj = self.get_query_set(instance=instance).get(**params)
+            try:
+                rel_obj = self.get_query_set(instance=instance).get(**params)
+            except self.related.model.DoesNotExist:
+                rel_obj = None
+            else:
+                setattr(rel_obj, self.related.field.get_cache_name(), instance)
             setattr(instance, self.cache_name, rel_obj)
+        if rel_obj is None:
+            raise self.related.model.DoesNotExist
+        else:
             return rel_obj
 
     def __set__(self, instance, value):
@@ -331,24 +339,27 @@ class ReverseSingleRelatedObjectDescriptor(object):
     def __get__(self, instance, instance_type=None):
         if instance is None:
             return self
-
         try:
-            return getattr(instance, self.cache_name)
+            rel_obj = getattr(instance, self.cache_name)
         except AttributeError:
             val = getattr(instance, self.field.attname)
             if val is None:
-                # If NULL is an allowed value, return it.
-                if self.field.null:
-                    return None
-                raise self.field.rel.to.DoesNotExist
-            other_field = self.field.rel.get_related_field()
-            if other_field.rel:
-                params = {'%s__pk' % self.field.rel.field_name: val}
+                rel_obj = None
             else:
-                params = {'%s__exact' % self.field.rel.field_name: val}
-            qs = self.get_query_set(instance=instance)
-            rel_obj = qs.get(**params)
+                other_field = self.field.rel.get_related_field()
+                if other_field.rel:
+                    params = {'%s__pk' % self.field.rel.field_name: val}
+                else:
+                    params = {'%s__exact' % self.field.rel.field_name: val}
+                qs = self.get_query_set(instance=instance)
+                # Assuming the database enforces foreign keys, this won't fail.
+                rel_obj = qs.get(**params)
+                if not self.field.rel.multiple:
+                    setattr(rel_obj, self.field.related.get_cache_name(), instance)
             setattr(instance, self.cache_name, rel_obj)
+        if rel_obj is None and not self.field.null:
+            raise self.field.rel.to.DoesNotExist
+        else:
             return rel_obj
 
     def __set__(self, instance, value):
@@ -385,17 +396,13 @@ class ReverseSingleRelatedObjectDescriptor(object):
             # populated the cache, then we don't care - we're only accessing
             # the object to invalidate the accessor cache, so there's no
             # need to populate the cache just to expire it again.
-            related = getattr(instance, self.field.get_cache_name(), None)
+            related = getattr(instance, self.cache_name, None)
 
             # If we've got an old related object, we need to clear out its
             # cache. This cache also might not exist if the related object
             # hasn't been accessed yet.
-            if related:
-                cache_name = self.field.related.get_cache_name()
-                try:
-                    delattr(related, cache_name)
-                except AttributeError:
-                    pass
+            if related is not None:
+                setattr(related, self.field.related.get_cache_name(), None)
 
         # Set the value of the related field
         try:
@@ -405,9 +412,11 @@ class ReverseSingleRelatedObjectDescriptor(object):
         setattr(instance, self.field.attname, val)
 
         # Since we already know what the related object is, seed the related
-        # object cache now, too. This avoids another db hit if you get the
+        # object caches now, too. This avoids another db hit if you get the
         # object you just set.
-        setattr(instance, self.field.get_cache_name(), value)
+        setattr(instance, self.cache_name, value)
+        if value is not None and not self.field.rel.multiple:
+            setattr(value, self.field.related.get_cache_name(), instance)
 
 class ForeignRelatedObjectsDescriptor(object):
     # This class provides the functionality that makes the related-object

+ 70 - 0
tests/regressiontests/one_to_one_regress/tests.py

@@ -132,3 +132,73 @@ class OneToOneRegressionTests(TestCase):
                 Target.objects.exclude(pointer2=None),
                 []
         )
+
+    def test_reverse_object_does_not_exist_cache(self):
+        """
+        Regression for #13839 and #17439.
+
+        DoesNotExist on a reverse one-to-one relation is cached.
+        """
+        p = Place(name='Zombie Cats', address='Not sure')
+        p.save()
+        with self.assertNumQueries(1):
+            with self.assertRaises(Restaurant.DoesNotExist):
+                p.restaurant
+        with self.assertNumQueries(0):
+            with self.assertRaises(Restaurant.DoesNotExist):
+                p.restaurant
+
+    def test_reverse_object_cached_when_related_is_accessed(self):
+        """
+        Regression for #13839 and #17439.
+
+        The target of a one-to-one relation is cached
+        when the origin is accessed through the reverse relation.
+        """
+        # Use a fresh object without caches
+        r = Restaurant.objects.get(pk=self.r1.pk)
+        p = r.place
+        with self.assertNumQueries(0):
+            self.assertEqual(p.restaurant, r)
+
+    def test_related_object_cached_when_reverse_is_accessed(self):
+        """
+        Regression for #13839 and #17439.
+
+        The origin of a one-to-one relation is cached
+        when the target is accessed through the reverse relation.
+        """
+        # Use a fresh object without caches
+        p = Place.objects.get(pk=self.p1.pk)
+        r = p.restaurant
+        with self.assertNumQueries(0):
+            self.assertEqual(r.place, p)
+
+    def test_reverse_object_cached_when_related_is_set(self):
+        """
+        Regression for #13839 and #17439.
+
+        The target of a one-to-one relation is always cached.
+        """
+        p = Place(name='Zombie Cats', address='Not sure')
+        p.save()
+        self.r1.place = p
+        self.r1.save()
+        with self.assertNumQueries(0):
+            self.assertEqual(p.restaurant, self.r1)
+
+    def test_reverse_object_cached_when_related_is_unset(self):
+        """
+        Regression for #13839 and #17439.
+
+        The target of a one-to-one relation is always cached.
+        """
+        b = UndergroundBar(place=self.p1, serves_cocktails=True)
+        b.save()
+        with self.assertNumQueries(0):
+            self.assertEqual(self.p1.undergroundbar, b)
+        b.place = None
+        b.save()
+        with self.assertNumQueries(0):
+            with self.assertRaises(UndergroundBar.DoesNotExist):
+                self.p1.undergroundbar

+ 29 - 1
tests/regressiontests/select_related_onetoone/tests.py

@@ -79,4 +79,32 @@ class ReverseSelectRelatedTestCase(TestCase):
         p1 = Product.objects.create(name="Django Plushie", image=im)
         p2 = Product.objects.create(name="Talking Django Plushie")
 
-        self.assertEqual(len(Product.objects.select_related("image")), 2)
+        with self.assertNumQueries(1):
+            result = sorted(Product.objects.select_related("image"), key=lambda x: x.name)
+            self.assertEqual([p.name for p in result], ["Django Plushie", "Talking Django Plushie"])
+
+            self.assertEqual(p1.image, im)
+            # Check for ticket #13839
+            self.assertIsNone(p2.image)
+
+    def test_missing_reverse(self):
+        """
+        Ticket #13839: select_related() should NOT cache None
+        for missing objects on a reverse 1-1 relation.
+        """
+        with self.assertNumQueries(1):
+            user = User.objects.select_related('userprofile').get(username='bob')
+            with self.assertRaises(UserProfile.DoesNotExist):
+                user.userprofile
+
+    def test_nullable_missing_reverse(self):
+        """
+        Ticket #13839: select_related() should NOT cache None
+        for missing objects on a reverse 0-1 relation.
+        """
+        Image.objects.create(name="imag1")
+
+        with self.assertNumQueries(1):
+            image = Image.objects.select_related('product').get()
+            with self.assertRaises(Product.DoesNotExist):
+                image.product