瀏覽代碼

Refs #27095 -- Allowed (non-nested) arrays containing expressions for ArrayField lookups.

Hannes Ljungberg 4 年之前
父節點
當前提交
33403bf80f
共有 3 個文件被更改,包括 53 次插入14 次删除
  1. 21 6
      django/contrib/postgres/fields/array.py
  2. 3 0
      docs/releases/3.2.txt
  3. 29 8
      tests/postgres_tests/test_array.py

+ 21 - 6
django/contrib/postgres/fields/array.py

@@ -4,7 +4,7 @@ from django.contrib.postgres import lookups
 from django.contrib.postgres.forms import SimpleArrayField
 from django.contrib.postgres.validators import ArrayMaxLengthValidator
 from django.core import checks, exceptions
-from django.db.models import Field, IntegerField, Transform
+from django.db.models import Field, Func, IntegerField, Transform, Value
 from django.db.models.fields.mixins import CheckFieldDefaultMixin
 from django.db.models.lookups import Exact, In
 from django.utils.translation import gettext_lazy as _
@@ -198,7 +198,22 @@ class ArrayField(CheckFieldDefaultMixin, Field):
         })
 
 
-class ArrayCastRHSMixin:
+class ArrayRHSMixin:
+    def __init__(self, lhs, rhs):
+        if isinstance(rhs, (tuple, list)):
+            expressions = []
+            for value in rhs:
+                if not hasattr(value, 'resolve_expression'):
+                    field = lhs.output_field
+                    value = Value(field.base_field.get_prep_value(value))
+                expressions.append(value)
+            rhs = Func(
+                *expressions,
+                function='ARRAY',
+                template='%(function)s[%(expressions)s]',
+            )
+        super().__init__(lhs, rhs)
+
     def process_rhs(self, compiler, connection):
         rhs, rhs_params = super().process_rhs(compiler, connection)
         cast_type = self.lhs.output_field.cast_db_type(connection)
@@ -206,22 +221,22 @@ class ArrayCastRHSMixin:
 
 
 @ArrayField.register_lookup
-class ArrayContains(ArrayCastRHSMixin, lookups.DataContains):
+class ArrayContains(ArrayRHSMixin, lookups.DataContains):
     pass
 
 
 @ArrayField.register_lookup
-class ArrayContainedBy(ArrayCastRHSMixin, lookups.ContainedBy):
+class ArrayContainedBy(ArrayRHSMixin, lookups.ContainedBy):
     pass
 
 
 @ArrayField.register_lookup
-class ArrayExact(ArrayCastRHSMixin, Exact):
+class ArrayExact(ArrayRHSMixin, Exact):
     pass
 
 
 @ArrayField.register_lookup
-class ArrayOverlap(ArrayCastRHSMixin, lookups.Overlap):
+class ArrayOverlap(ArrayRHSMixin, lookups.Overlap):
     pass
 
 

+ 3 - 0
docs/releases/3.2.txt

@@ -143,6 +143,9 @@ Minor features
   allow creating and dropping collations on PostgreSQL. See
   :ref:`manage-postgresql-collations` for more details.
 
+* Lookups for :class:`~django.contrib.postgres.fields.ArrayField` now allow
+  (non-nested) arrays containing expressions as right-hand sides.
+
 :mod:`django.contrib.redirects`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

+ 29 - 8
tests/postgres_tests/test_array.py

@@ -9,8 +9,8 @@ from django.core import checks, exceptions, serializers, validators
 from django.core.exceptions import FieldError
 from django.core.management import call_command
 from django.db import IntegrityError, connection, models
-from django.db.models.expressions import Exists, OuterRef, RawSQL
-from django.db.models.functions import Cast
+from django.db.models.expressions import Exists, OuterRef, RawSQL, Value
+from django.db.models.functions import Cast, Upper
 from django.test import TransactionTestCase, modify_settings, override_settings
 from django.test.utils import isolate_apps
 from django.utils import timezone
@@ -226,6 +226,12 @@ class TestQuerying(PostgreSQLTestCase):
             self.objs[:1]
         )
 
+    def test_exact_with_expression(self):
+        self.assertSequenceEqual(
+            NullableIntegerArrayModel.objects.filter(field__exact=[Value(1)]),
+            self.objs[:1],
+        )
+
     def test_exact_charfield(self):
         instance = CharArrayModel.objects.create(field=['text'])
         self.assertSequenceEqual(
@@ -296,15 +302,10 @@ class TestQuerying(PostgreSQLTestCase):
             self.objs[:2]
         )
 
-    @unittest.expectedFailure
     def test_contained_by_including_F_object(self):
-        # This test asserts that Array objects passed to filters can be
-        # constructed to contain F objects. This currently doesn't work as the
-        # psycopg2 mogrify method that generates the ARRAY() syntax is
-        # expecting literals, not column references (#27095).
         self.assertSequenceEqual(
             NullableIntegerArrayModel.objects.filter(field__contained_by=[models.F('id'), 2]),
-            self.objs[:2]
+            self.objs[:3],
         )
 
     def test_contains(self):
@@ -326,6 +327,14 @@ class TestQuerying(PostgreSQLTestCase):
             self.objs[1:3],
         )
 
+    def test_contains_including_expression(self):
+        self.assertSequenceEqual(
+            NullableIntegerArrayModel.objects.filter(
+                field__contains=[2, Value(6) / Value(2)],
+            ),
+            self.objs[2:3],
+        )
+
     def test_icontains(self):
         # Using the __icontains lookup with ArrayField is inefficient.
         instance = CharArrayModel.objects.create(field=['FoO'])
@@ -353,6 +362,18 @@ class TestQuerying(PostgreSQLTestCase):
             []
         )
 
+    def test_overlap_charfield_including_expression(self):
+        obj_1 = CharArrayModel.objects.create(field=['TEXT', 'lower text'])
+        obj_2 = CharArrayModel.objects.create(field=['lower text', 'TEXT'])
+        CharArrayModel.objects.create(field=['lower text', 'text'])
+        self.assertSequenceEqual(
+            CharArrayModel.objects.filter(field__overlap=[
+                Upper(Value('text')),
+                'other',
+            ]),
+            [obj_1, obj_2],
+        )
+
     def test_lookups_autofield_array(self):
         qs = NullableIntegerArrayModel.objects.filter(
             field__0__isnull=False,