|
@@ -260,15 +260,6 @@ class CaseInsensitiveMixin:
|
|
|
class JSONExact(lookups.Exact):
|
|
|
can_use_none_as_rhs = True
|
|
|
|
|
|
- def process_lhs(self, compiler, connection):
|
|
|
- lhs, lhs_params = super().process_lhs(compiler, connection)
|
|
|
- if connection.vendor == 'sqlite':
|
|
|
- rhs, rhs_params = super().process_rhs(compiler, connection)
|
|
|
- if rhs == '%s' and rhs_params == [None]:
|
|
|
- # Use JSON_TYPE instead of JSON_EXTRACT for NULLs.
|
|
|
- lhs = "JSON_TYPE(%s, '$')" % lhs
|
|
|
- return lhs, lhs_params
|
|
|
-
|
|
|
def process_rhs(self, compiler, connection):
|
|
|
rhs, rhs_params = super().process_rhs(compiler, connection)
|
|
|
# Treat None lookup values as null.
|
|
@@ -340,7 +331,13 @@ class KeyTransform(Transform):
|
|
|
def as_sqlite(self, compiler, connection):
|
|
|
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
|
|
json_path = compile_json_path(key_transforms)
|
|
|
- return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,)
|
|
|
+ datatype_values = ','.join([
|
|
|
+ repr(datatype) for datatype in connection.ops.jsonfield_datatype_values
|
|
|
+ ])
|
|
|
+ return (
|
|
|
+ "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
|
|
|
+ "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
|
|
|
+ ) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3
|
|
|
|
|
|
|
|
|
class KeyTextTransform(KeyTransform):
|
|
@@ -408,7 +405,10 @@ class KeyTransformIn(lookups.In):
|
|
|
sql = sql % 'JSON_QUERY'
|
|
|
else:
|
|
|
sql = sql % 'JSON_VALUE'
|
|
|
- elif connection.vendor in {'sqlite', 'mysql'}:
|
|
|
+ elif connection.vendor == 'mysql' or (
|
|
|
+ connection.vendor == 'sqlite' and
|
|
|
+ params[0] not in connection.ops.jsonfield_datatype_values
|
|
|
+ ):
|
|
|
sql = "JSON_EXTRACT(%s, '$')"
|
|
|
if connection.vendor == 'mysql' and connection.mysql_is_mariadb:
|
|
|
sql = 'JSON_UNQUOTE(%s)' % sql
|
|
@@ -416,15 +416,6 @@ class KeyTransformIn(lookups.In):
|
|
|
|
|
|
|
|
|
class KeyTransformExact(JSONExact):
|
|
|
- def process_lhs(self, compiler, connection):
|
|
|
- lhs, lhs_params = super().process_lhs(compiler, connection)
|
|
|
- if connection.vendor == 'sqlite':
|
|
|
- rhs, rhs_params = super().process_rhs(compiler, connection)
|
|
|
- if rhs == '%s' and rhs_params == ['null']:
|
|
|
- lhs, *_ = self.lhs.preprocess_lhs(compiler, connection)
|
|
|
- lhs = 'JSON_TYPE(%s, %%s)' % lhs
|
|
|
- return lhs, lhs_params
|
|
|
-
|
|
|
def process_rhs(self, compiler, connection):
|
|
|
if isinstance(self.rhs, KeyTransform):
|
|
|
return super(lookups.Exact, self).process_rhs(compiler, connection)
|
|
@@ -440,7 +431,12 @@ class KeyTransformExact(JSONExact):
|
|
|
func.append(sql % 'JSON_VALUE')
|
|
|
rhs = rhs % tuple(func)
|
|
|
elif connection.vendor == 'sqlite':
|
|
|
- func = ["JSON_EXTRACT(%s, '$')" if value != 'null' else '%s' for value in rhs_params]
|
|
|
+ func = []
|
|
|
+ for value in rhs_params:
|
|
|
+ if value in connection.ops.jsonfield_datatype_values:
|
|
|
+ func.append('%s')
|
|
|
+ else:
|
|
|
+ func.append("JSON_EXTRACT(%s, '$')")
|
|
|
rhs = rhs % tuple(func)
|
|
|
return rhs, rhs_params
|
|
|
|