瀏覽代碼

Fixed #16055 -- Fixed crash when filtering against char/text GenericRelation relation on PostgreSQL.

David Wobrock 1 年之前
父節點
當前提交
9bbf97bcdb

+ 7 - 0
django/db/backends/base/operations.py

@@ -8,6 +8,7 @@ import sqlparse
 from django.conf import settings
 from django.db import NotSupportedError, transaction
 from django.db.backends import utils
+from django.db.models.expressions import Col
 from django.utils import timezone
 from django.utils.encoding import force_str
 
@@ -776,3 +777,9 @@ class BaseDatabaseOperations:
 
     def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
         return ""
+
+    def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
+        lhs_expr = Col(lhs_table, lhs_field)
+        rhs_expr = Col(rhs_table, rhs_field)
+
+        return lhs_expr, rhs_expr

+ 3 - 0
django/db/backends/oracle/features.py

@@ -120,6 +120,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
             "migrations.test_operations.OperationTests."
             "test_alter_field_pk_fk_db_collation",
         },
+        "Oracle doesn't support comparing NCLOB to NUMBER.": {
+            "generic_relations_regress.tests.GenericRelationTests.test_textlink_filter",
+        },
     }
     django_test_expected_failures = {
         # A bug in Django/cx_Oracle with respect to string handling (#23843).

+ 11 - 0
django/db/backends/postgresql/operations.py

@@ -12,6 +12,7 @@ from django.db.backends.postgresql.psycopg_any import (
 )
 from django.db.backends.utils import split_tzname_delta
 from django.db.models.constants import OnConflict
+from django.db.models.functions import Cast
 from django.utils.regex_helper import _lazy_re_compile
 
 
@@ -413,3 +414,13 @@ class DatabaseOperations(BaseDatabaseOperations):
             update_fields,
             unique_fields,
         )
+
+    def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
+        lhs_expr, rhs_expr = super().prepare_join_on_clause(
+            lhs_table, lhs_field, rhs_table, rhs_field
+        )
+
+        if lhs_field.db_type(self.connection) != rhs_field.db_type(self.connection):
+            rhs_expr = Cast(rhs_expr, lhs_field)
+
+        return lhs_expr, rhs_expr

+ 8 - 0
django/db/models/fields/related.py

@@ -785,6 +785,14 @@ class ForeignObject(RelatedField):
     def get_reverse_joining_columns(self):
         return self.get_joining_columns(reverse_join=True)
 
+    def get_joining_fields(self, reverse_join=False):
+        return tuple(
+            self.reverse_related_fields if reverse_join else self.related_fields
+        )
+
+    def get_reverse_joining_fields(self):
+        return self.get_joining_fields(reverse_join=True)
+
     def get_extra_descriptor_filter(self, instance):
         """
         Return an extra filter condition for related object fetching when

+ 3 - 0
django/db/models/fields/reverse_related.py

@@ -195,6 +195,9 @@ class ForeignObjectRel(FieldCacheMixin):
     def get_joining_columns(self):
         return self.field.get_reverse_joining_columns()
 
+    def get_joining_fields(self):
+        return self.field.get_reverse_joining_fields()
+
     def get_extra_restriction(self, alias, related_alias):
         return self.field.get_extra_restriction(related_alias, alias)
 

+ 22 - 11
django/db/models/sql/datastructures.py

@@ -61,7 +61,15 @@ class Join:
         self.join_type = join_type
         # A list of 2-tuples to use in the ON clause of the JOIN.
         # Each 2-tuple will create one join condition in the ON clause.
-        self.join_cols = join_field.get_joining_columns()
+        if hasattr(join_field, "get_joining_fields"):
+            self.join_fields = join_field.get_joining_fields()
+            self.join_cols = tuple(
+                (lhs_field.column, rhs_field.column)
+                for lhs_field, rhs_field in self.join_fields
+            )
+        else:
+            self.join_fields = None
+            self.join_cols = join_field.get_joining_columns()
         # Along which field (or ForeignObjectRel in the reverse join case)
         self.join_field = join_field
         # Is this join nullabled?
@@ -78,18 +86,21 @@ class Join:
         params = []
         qn = compiler.quote_name_unless_alias
         qn2 = connection.ops.quote_name
-
         # Add a join condition for each pair of joining columns.
-        for lhs_col, rhs_col in self.join_cols:
-            join_conditions.append(
-                "%s.%s = %s.%s"
-                % (
-                    qn(self.parent_alias),
-                    qn2(lhs_col),
-                    qn(self.table_alias),
-                    qn2(rhs_col),
+        join_fields = self.join_fields or self.join_cols
+        for lhs, rhs in join_fields:
+            if isinstance(lhs, str):
+                lhs_full_name = "%s.%s" % (qn(self.parent_alias), qn2(lhs))
+                rhs_full_name = "%s.%s" % (qn(self.table_alias), qn2(rhs))
+            else:
+                lhs, rhs = connection.ops.prepare_join_on_clause(
+                    self.parent_alias, lhs, self.table_alias, rhs
                 )
-            )
+                lhs_sql, lhs_params = compiler.compile(lhs)
+                lhs_full_name = lhs_sql % lhs_params
+                rhs_sql, rhs_params = compiler.compile(rhs)
+                rhs_full_name = rhs_sql % rhs_params
+            join_conditions.append(f"{lhs_full_name} = {rhs_full_name}")
 
         # Add a single condition inside parentheses for whatever
         # get_extra_restriction() returns.

+ 15 - 0
tests/backends/base/test_operations.py

@@ -4,6 +4,7 @@ from django.core.management.color import no_style
 from django.db import NotSupportedError, connection, transaction
 from django.db.backends.base.operations import BaseDatabaseOperations
 from django.db.models import DurationField, Value
+from django.db.models.expressions import Col
 from django.test import (
     SimpleTestCase,
     TestCase,
@@ -159,6 +160,20 @@ class SimpleDatabaseOperationTests(SimpleTestCase):
         ):
             self.ops.datetime_extract_sql(None, None, None, None)
 
+    def test_prepare_join_on_clause(self):
+        author_table = Author._meta.db_table
+        author_id_field = Author._meta.get_field("id")
+        book_table = Book._meta.db_table
+        book_fk_field = Book._meta.get_field("author")
+        lhs_expr, rhs_expr = self.ops.prepare_join_on_clause(
+            author_table,
+            author_id_field,
+            book_table,
+            book_fk_field,
+        )
+        self.assertEqual(lhs_expr, Col(author_table, author_id_field))
+        self.assertEqual(rhs_expr, Col(book_table, book_fk_field))
+
 
 class DatabaseOperationTests(TestCase):
     def setUp(self):

+ 31 - 1
tests/backends/postgresql/test_operations.py

@@ -2,9 +2,11 @@ import unittest
 
 from django.core.management.color import no_style
 from django.db import connection
+from django.db.models.expressions import Col
+from django.db.models.functions import Cast
 from django.test import SimpleTestCase
 
-from ..models import Person, Tag
+from ..models import Author, Book, Person, Tag
 
 
 @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL tests.")
@@ -48,3 +50,31 @@ class PostgreSQLOperationsTests(SimpleTestCase):
             ),
             ['TRUNCATE "backends_person", "backends_tag" RESTART IDENTITY CASCADE;'],
         )
+
+    def test_prepare_join_on_clause_same_type(self):
+        author_table = Author._meta.db_table
+        author_id_field = Author._meta.get_field("id")
+        lhs_expr, rhs_expr = connection.ops.prepare_join_on_clause(
+            author_table,
+            author_id_field,
+            author_table,
+            author_id_field,
+        )
+        self.assertEqual(lhs_expr, Col(author_table, author_id_field))
+        self.assertEqual(rhs_expr, Col(author_table, author_id_field))
+
+    def test_prepare_join_on_clause_different_types(self):
+        author_table = Author._meta.db_table
+        author_id_field = Author._meta.get_field("id")
+        book_table = Book._meta.db_table
+        book_fk_field = Book._meta.get_field("author")
+        lhs_expr, rhs_expr = connection.ops.prepare_join_on_clause(
+            author_table,
+            author_id_field,
+            book_table,
+            book_fk_field,
+        )
+        self.assertEqual(lhs_expr, Col(author_table, author_id_field))
+        self.assertEqual(
+            rhs_expr, Cast(Col(book_table, book_fk_field), author_id_field)
+        )

+ 1 - 1
tests/foreign_object/models/empty_join.py

@@ -50,7 +50,7 @@ class StartsWithRelation(models.ForeignObject):
         from_field = self.model._meta.get_field(self.from_fields[0])
         return StartsWith(to_field.get_col(alias), from_field.get_col(related_alias))
 
-    def get_joining_columns(self, reverse_join=False):
+    def get_joining_fields(self, reverse_join=False):
         return ()
 
     def get_path_info(self, filtered_relation=None):

+ 2 - 0
tests/generic_relations_regress/models.py

@@ -64,12 +64,14 @@ class CharLink(models.Model):
     content_type = models.ForeignKey(ContentType, models.CASCADE)
     object_id = models.CharField(max_length=100)
     content_object = GenericForeignKey()
+    value = models.CharField(max_length=250)
 
 
 class TextLink(models.Model):
     content_type = models.ForeignKey(ContentType, models.CASCADE)
     object_id = models.TextField()
     content_object = GenericForeignKey()
+    value = models.CharField(max_length=250)
 
 
 class OddRelation1(models.Model):

+ 14 - 0
tests/generic_relations_regress/tests.py

@@ -72,6 +72,20 @@ class GenericRelationTests(TestCase):
         TextLink.objects.create(content_object=oddrel)
         oddrel.delete()
 
+    def test_charlink_filter(self):
+        oddrel = OddRelation1.objects.create(name="clink")
+        CharLink.objects.create(content_object=oddrel, value="value")
+        self.assertSequenceEqual(
+            OddRelation1.objects.filter(clinks__value="value"), [oddrel]
+        )
+
+    def test_textlink_filter(self):
+        oddrel = OddRelation2.objects.create(name="clink")
+        TextLink.objects.create(content_object=oddrel, value="value")
+        self.assertSequenceEqual(
+            OddRelation2.objects.filter(tlinks__value="value"), [oddrel]
+        )
+
     def test_coerce_object_id_remote_field_cache_persistence(self):
         restaurant = Restaurant.objects.create()
         CharLink.objects.create(content_object=restaurant)