Răsfoiți Sursa

Fixed #15161 - Corrected handling of ManyToManyField with through table using to_field on its ForeignKeys. Thanks to adehnert for the report.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@15330 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Carl Meyer 14 ani în urmă
părinte
comite
84291b7b84

+ 6 - 0
django/contrib/contenttypes/generic.py

@@ -135,6 +135,12 @@ class GenericRelation(RelatedField, Field):
     def m2m_reverse_name(self):
         return self.rel.to._meta.pk.column
 
+    def m2m_target_field_name(self):
+        return self.model._meta.pk.name
+
+    def m2m_reverse_target_field_name(self):
+        return self.rel.to._meta.pk.name
+
     def contribute_to_class(self, cls, name):
         super(GenericRelation, self).contribute_to_class(cls, name)
 

+ 5 - 0
django/db/models/fields/related.py

@@ -1131,6 +1131,11 @@ class ManyToManyField(RelatedField, Field):
         self.m2m_field_name = curry(self._get_m2m_attr, related, 'name')
         self.m2m_reverse_field_name = curry(self._get_m2m_reverse_attr, related, 'name')
 
+        get_m2m_rel = curry(self._get_m2m_attr, related, 'rel')
+        self.m2m_target_field_name = lambda: get_m2m_rel().field_name
+        get_m2m_reverse_rel = curry(self._get_m2m_reverse_attr, related, 'rel')
+        self.m2m_reverse_target_field_name = lambda: get_m2m_reverse_rel().field_name
+
     def set_attributes_from_rel(self):
         pass
 

+ 8 - 4
django/db/models/sql/query.py

@@ -1282,12 +1282,14 @@ class Query(object):
                                 to_col2, opts, target) = cached_data
                     else:
                         table1 = field.m2m_db_table()
-                        from_col1 = opts.pk.column
+                        from_col1 = opts.get_field_by_name(
+                            field.m2m_target_field_name())[0].column
                         to_col1 = field.m2m_column_name()
                         opts = field.rel.to._meta
                         table2 = opts.db_table
                         from_col2 = field.m2m_reverse_name()
-                        to_col2 = opts.pk.column
+                        to_col2 = opts.get_field_by_name(
+                            field.m2m_reverse_target_field_name())[0].column
                         target = opts.pk
                         orig_opts._join_cache[name] = (table1, from_col1,
                                 to_col1, table2, from_col2, to_col2, opts,
@@ -1335,12 +1337,14 @@ class Query(object):
                                 to_col2, opts, target) = cached_data
                     else:
                         table1 = field.m2m_db_table()
-                        from_col1 = opts.pk.column
+                        from_col1 = opts.get_field_by_name(
+                            field.m2m_reverse_target_field_name())[0].column
                         to_col1 = field.m2m_reverse_name()
                         opts = orig_field.opts
                         table2 = opts.db_table
                         from_col2 = field.m2m_column_name()
-                        to_col2 = opts.pk.column
+                        to_col2 = opts.get_field_by_name(
+                            field.m2m_target_field_name())[0].column
                         target = opts.pk
                         orig_opts._join_cache[name] = (table1, from_col1,
                                 to_col1, table2, from_col2, to_col2, opts,

+ 22 - 0
tests/regressiontests/m2m_through_regress/models.py

@@ -53,3 +53,25 @@ class Through(ThroughBase):
 class B(models.Model):
     b_text = models.CharField(max_length=20)
     a_list = models.ManyToManyField(A, through=Through)
+
+
+# Using to_field on the through model
+class Car(models.Model):
+    make = models.CharField(max_length=20, unique=True)
+    drivers = models.ManyToManyField('Driver', through='CarDriver')
+
+    def __unicode__(self, ):
+        return self.make
+
+class Driver(models.Model):
+    name = models.CharField(max_length=20, unique=True)
+
+    def __unicode__(self, ):
+        return self.name
+
+class CarDriver(models.Model):
+    car = models.ForeignKey('Car', to_field='make')
+    driver = models.ForeignKey('Driver', to_field='name')
+
+    def __unicode__(self, ):
+        return u"pk=%s car=%s driver=%s" % (str(self.pk), self.car, self.driver)

+ 21 - 1
tests/regressiontests/m2m_through_regress/tests.py

@@ -7,7 +7,8 @@ from django.core import management
 from django.contrib.auth.models import User
 from django.test import TestCase
 
-from models import Person, Group, Membership, UserMembership
+from models import (Person, Group, Membership, UserMembership,
+                    Car, Driver, CarDriver)
 
 
 class M2MThroughTestCase(TestCase):
@@ -118,6 +119,25 @@ class M2MThroughTestCase(TestCase):
             ]
         )
 
+
+class ToFieldThroughTests(TestCase):
+    def setUp(self):
+        self.car = Car.objects.create(make="Toyota")
+        self.driver = Driver.objects.create(name="Ryan Briscoe")
+        CarDriver.objects.create(car=self.car, driver=self.driver)
+
+    def test_to_field(self):
+        self.assertQuerysetEqual(
+            self.car.drivers.all(),
+            ["<Driver: Ryan Briscoe>"]
+            )
+
+    def test_to_field_reverse(self):
+        self.assertQuerysetEqual(
+            self.driver.car_set.all(),
+            ["<Car: Toyota>"]
+            )
+
 class ThroughLoadDataTestCase(TestCase):
     fixtures = ["m2m_through"]