Browse Source

Fixed #24858 -- Added support for get_FOO_display() to ArrayField and RangeFields.

_get_FIELD_display() crashed when Field.choices was unhashable.
Hasan Ramezani 5 years ago
parent
commit
153c7956f8

+ 3 - 1
django/db/models/base.py

@@ -33,6 +33,7 @@ from django.db.models.signals import (
 )
 from django.db.models.utils import make_model_tuple
 from django.utils.encoding import force_str
+from django.utils.hashable import make_hashable
 from django.utils.text import capfirst, get_text_list
 from django.utils.translation import gettext_lazy as _
 from django.utils.version import get_version
@@ -940,8 +941,9 @@ class Model(metaclass=ModelBase):
 
     def _get_FIELD_display(self, field):
         value = getattr(self, field.attname)
+        choices_dict = dict(make_hashable(field.flatchoices))
         # force_str() to coerce lazy strings.
-        return force_str(dict(field.flatchoices).get(value, value), strings_only=True)
+        return force_str(choices_dict.get(make_hashable(value), value), strings_only=True)
 
     def _get_next_or_previous_by_FIELD(self, field, is_next, **kwargs):
         if not self.pk:

+ 5 - 0
docs/ref/models/instances.txt

@@ -797,6 +797,11 @@ For example::
     >>> p.get_shirt_size_display()
     'Large'
 
+.. versionchanged:: 3.1
+
+    Support for :class:`~django.contrib.postgres.fields.ArrayField` and
+    :class:`~django.contrib.postgres.fields.RangeField` was added.
+
 .. method:: Model.get_next_by_FOO(**kwargs)
 .. method:: Model.get_previous_by_FOO(**kwargs)
 

+ 4 - 0
docs/releases/3.1.txt

@@ -76,6 +76,10 @@ Minor features
   :class:`~django.contrib.postgres.operations.BloomExtension` migration
   operation installs the ``bloom`` extension to add support for this index.
 
+* :meth:`~django.db.models.Model.get_FOO_display` now supports
+  :class:`~django.contrib.postgres.fields.ArrayField` and
+  :class:`~django.contrib.postgres.fields.RangeField`.
+
 :mod:`django.contrib.redirects`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

+ 47 - 0
tests/postgres_tests/test_array.py

@@ -37,6 +37,53 @@ except ImportError:
     pass
 
 
+@isolate_apps('postgres_tests')
+class BasicTests(PostgreSQLSimpleTestCase):
+    def test_get_field_display(self):
+        class MyModel(PostgreSQLModel):
+            field = ArrayField(
+                models.CharField(max_length=16),
+                choices=[
+                    ['Media', [(['vinyl', 'cd'], 'Audio')]],
+                    (('mp3', 'mp4'), 'Digital'),
+                ],
+            )
+
+        tests = (
+            (['vinyl', 'cd'], 'Audio'),
+            (('mp3', 'mp4'), 'Digital'),
+            (('a', 'b'), "('a', 'b')"),
+            (['c', 'd'], "['c', 'd']"),
+        )
+        for value, display in tests:
+            with self.subTest(value=value, display=display):
+                instance = MyModel(field=value)
+                self.assertEqual(instance.get_field_display(), display)
+
+    def test_get_field_display_nested_array(self):
+        class MyModel(PostgreSQLModel):
+            field = ArrayField(
+                ArrayField(models.CharField(max_length=16)),
+                choices=[
+                    [
+                        'Media',
+                        [([['vinyl', 'cd'], ('x',)], 'Audio')],
+                    ],
+                    ((['mp3'], ('mp4',)), 'Digital'),
+                ],
+            )
+        tests = (
+            ([['vinyl', 'cd'], ('x',)], 'Audio'),
+            ((['mp3'], ('mp4',)), 'Digital'),
+            ((('a', 'b'), ('c',)), "(('a', 'b'), ('c',))"),
+            ([['a', 'b'], ['c']], "[['a', 'b'], ['c']]"),
+        )
+        for value, display in tests:
+            with self.subTest(value=value, display=display):
+                instance = MyModel(field=value)
+                self.assertEqual(instance.get_field_display(), display)
+
+
 class TestSaveLoad(PostgreSQLTestCase):
 
     def test_integer(self):

+ 25 - 0
tests/postgres_tests/test_ranges.py

@@ -7,6 +7,7 @@ from django.core import exceptions, serializers
 from django.db.models import DateField, DateTimeField, F, Func, Value
 from django.http import QueryDict
 from django.test import override_settings
+from django.test.utils import isolate_apps
 from django.utils import timezone
 
 from . import PostgreSQLSimpleTestCase, PostgreSQLTestCase
@@ -22,6 +23,30 @@ except ImportError:
     pass
 
 
+@isolate_apps('postgres_tests')
+class BasicTests(PostgreSQLSimpleTestCase):
+    def test_get_field_display(self):
+        class Model(PostgreSQLModel):
+            field = pg_fields.IntegerRangeField(
+                choices=[
+                    ['1-50', [((1, 25), '1-25'), ([26, 50], '26-50')]],
+                    ((51, 100), '51-100'),
+                ],
+            )
+
+        tests = (
+            ((1, 25), '1-25'),
+            ([26, 50], '26-50'),
+            ((51, 100), '51-100'),
+            ((1, 2), '(1, 2)'),
+            ([1, 2], '[1, 2]'),
+        )
+        for value, display in tests:
+            with self.subTest(value=value, display=display):
+                instance = Model(field=value)
+                self.assertEqual(instance.get_field_display(), display)
+
+
 class TestSaveLoad(PostgreSQLTestCase):
 
     def test_all_fields(self):