Ver Fonte

Fixed #32483 -- Fixed QuerySet.values()/values_list() on JSONField key transforms with booleans on SQLite.

Thanks Matthew Cornell for the report.
Mariusz Felisiak há 4 anos atrás
pai
commit
71ec102b01

+ 3 - 0
django/db/backends/sqlite3/operations.py

@@ -21,6 +21,9 @@ class DatabaseOperations(BaseDatabaseOperations):
         'DateTimeField': 'TEXT',
     }
     explain_prefix = 'EXPLAIN QUERY PLAN'
+    # List of datatypes to that cannot be extracted with JSON_EXTRACT() on
+    # SQLite. Use JSON_TYPE() instead.
+    jsonfield_datatype_values = frozenset(['null', 'false', 'true'])
 
     def bulk_batch_size(self, fields, objs):
         """

+ 17 - 21
django/db/models/fields/json.py

@@ -260,15 +260,6 @@ class CaseInsensitiveMixin:
 class JSONExact(lookups.Exact):
     can_use_none_as_rhs = True
 
-    def process_lhs(self, compiler, connection):
-        lhs, lhs_params = super().process_lhs(compiler, connection)
-        if connection.vendor == 'sqlite':
-            rhs, rhs_params = super().process_rhs(compiler, connection)
-            if rhs == '%s' and rhs_params == [None]:
-                # Use JSON_TYPE instead of JSON_EXTRACT for NULLs.
-                lhs = "JSON_TYPE(%s, '$')" % lhs
-        return lhs, lhs_params
-
     def process_rhs(self, compiler, connection):
         rhs, rhs_params = super().process_rhs(compiler, connection)
         # Treat None lookup values as null.
@@ -340,7 +331,13 @@ class KeyTransform(Transform):
     def as_sqlite(self, compiler, connection):
         lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
         json_path = compile_json_path(key_transforms)
-        return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,)
+        datatype_values = ','.join([
+            repr(datatype) for datatype in connection.ops.jsonfield_datatype_values
+        ])
+        return (
+            "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
+            "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
+        ) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3
 
 
 class KeyTextTransform(KeyTransform):
@@ -408,7 +405,10 @@ class KeyTransformIn(lookups.In):
                     sql = sql % 'JSON_QUERY'
                 else:
                     sql = sql % 'JSON_VALUE'
-            elif connection.vendor in {'sqlite', 'mysql'}:
+            elif connection.vendor == 'mysql' or (
+                connection.vendor == 'sqlite' and
+                params[0] not in connection.ops.jsonfield_datatype_values
+            ):
                 sql = "JSON_EXTRACT(%s, '$')"
         if connection.vendor == 'mysql' and connection.mysql_is_mariadb:
             sql = 'JSON_UNQUOTE(%s)' % sql
@@ -416,15 +416,6 @@ class KeyTransformIn(lookups.In):
 
 
 class KeyTransformExact(JSONExact):
-    def process_lhs(self, compiler, connection):
-        lhs, lhs_params = super().process_lhs(compiler, connection)
-        if connection.vendor == 'sqlite':
-            rhs, rhs_params = super().process_rhs(compiler, connection)
-            if rhs == '%s' and rhs_params == ['null']:
-                lhs, *_ = self.lhs.preprocess_lhs(compiler, connection)
-                lhs = 'JSON_TYPE(%s, %%s)' % lhs
-        return lhs, lhs_params
-
     def process_rhs(self, compiler, connection):
         if isinstance(self.rhs, KeyTransform):
             return super(lookups.Exact, self).process_rhs(compiler, connection)
@@ -440,7 +431,12 @@ class KeyTransformExact(JSONExact):
                     func.append(sql % 'JSON_VALUE')
             rhs = rhs % tuple(func)
         elif connection.vendor == 'sqlite':
-            func = ["JSON_EXTRACT(%s, '$')" if value != 'null' else '%s' for value in rhs_params]
+            func = []
+            for value in rhs_params:
+                if value in connection.ops.jsonfield_datatype_values:
+                    func.append('%s')
+                else:
+                    func.append("JSON_EXTRACT(%s, '$')")
             rhs = rhs % tuple(func)
         return rhs, rhs_params
 

+ 0 - 12
docs/ref/models/querysets.txt

@@ -695,12 +695,6 @@ You can also refer to fields on related models with reverse relations through
    pronounced if you include multiple such fields in your ``values()`` query,
    in which case all possible combinations will be returned.
 
-.. admonition:: Boolean values for ``JSONField`` on SQLite
-
-    Due to the way the ``JSON_EXTRACT`` SQL function is implemented on SQLite,
-    ``values()`` will return ``1`` and ``0`` instead of ``True`` and ``False``
-    for :class:`~django.db.models.JSONField` key transforms.
-
 ``values_list()``
 ~~~~~~~~~~~~~~~~~
 
@@ -771,12 +765,6 @@ not having any author::
     >>> Entry.objects.values_list('authors')
     <QuerySet [('Noam Chomsky',), ('George Orwell',), (None,)]>
 
-.. admonition:: Boolean values for ``JSONField`` on SQLite
-
-    Due to the way the ``JSON_EXTRACT`` SQL function is implemented on SQLite,
-    ``values_list()`` will return ``1`` and ``0`` instead of ``True`` and
-    ``False`` for :class:`~django.db.models.JSONField` key transforms.
-
 ``dates()``
 ~~~~~~~~~~~
 

+ 10 - 0
tests/model_fields/test_jsonfield.py

@@ -808,6 +808,16 @@ class TestQuerying(TestCase):
             with self.subTest(lookup=lookup):
                 self.assertEqual(qs.values_list(lookup, flat=True).get(), expected)
 
+    def test_key_values_boolean(self):
+        qs = NullableJSONModel.objects.filter(value__h=True, value__i=False)
+        tests = [
+            ('value__h', True),
+            ('value__i', False),
+        ]
+        for lookup, expected in tests:
+            with self.subTest(lookup=lookup):
+                self.assertIs(qs.values_list(lookup, flat=True).get(), expected)
+
     @skipUnlessDBFeature('supports_json_field_contains')
     def test_key_contains(self):
         self.assertIs(NullableJSONModel.objects.filter(value__foo__contains='ar').exists(), False)