123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510 |
- import json
- from django import forms
- from django.core import checks, exceptions
- from django.db import NotSupportedError, connections, router
- from django.db.models import lookups
- from django.db.models.lookups import PostgresOperatorLookup, Transform
- from django.utils.translation import gettext_lazy as _
- from . import Field
- from .mixins import CheckFieldDefaultMixin
# Names exported by ``from <this module> import *``.
__all__ = ['JSONField']
class JSONField(CheckFieldDefaultMixin, Field):
    """
    Model field storing JSON-serializable Python values.

    Values are encoded with ``json.dumps(value, cls=self.encoder)`` before
    they reach the database and decoded with
    ``json.loads(value, cls=self.decoder)`` when read back, unless the backend
    already returns decoded values.
    """
    empty_strings_allowed = False
    description = _('A JSON object')
    default_error_messages = {
        'invalid': _('Value must be valid JSON.'),
    }
    # Suggested callable default for the mutable-default check (presumably
    # read by CheckFieldDefaultMixin — confirm).
    _default_hint = ('dict', '{}')

    def __init__(
        self, verbose_name=None, name=None, encoder=None, decoder=None,
        **kwargs,
    ):
        """
        :param encoder: optional callable (typically a ``json.JSONEncoder``
            subclass) passed as ``cls`` to ``json.dumps``.
        :param decoder: optional callable (typically a ``json.JSONDecoder``
            subclass) passed as ``cls`` to ``json.loads``.
        :raises ValueError: if ``encoder`` or ``decoder`` is given but is not
            callable.
        """
        if encoder and not callable(encoder):
            raise ValueError('The encoder parameter must be a callable object.')
        if decoder and not callable(decoder):
            raise ValueError('The decoder parameter must be a callable object.')
        self.encoder = encoder
        self.decoder = decoder
        super().__init__(verbose_name, name, **kwargs)

    def check(self, **kwargs):
        """Run the base field checks plus the backend JSON-support check."""
        errors = super().check(**kwargs)
        databases = kwargs.get('databases') or []
        errors.extend(self._check_supported(databases))
        return errors

    def _check_supported(self, databases):
        """
        Return a ``fields.E180`` error for each database in ``databases``
        that this model migrates to but whose backend lacks JSONField support.
        """
        errors = []
        for db in databases:
            if not router.allow_migrate_model(db, self.model):
                continue
            connection = connections[db]
            # A model pinned to another vendor is never created on this
            # connection, so that backend's JSON support is irrelevant.
            if (
                self.model._meta.required_db_vendor and
                self.model._meta.required_db_vendor != connection.vendor
            ):
                continue
            # Listing the feature in Meta.required_db_features skips the check.
            if not (
                'supports_json_field' in self.model._meta.required_db_features or
                connection.features.supports_json_field
            ):
                errors.append(
                    checks.Error(
                        '%s does not support JSONFields.'
                        % connection.display_name,
                        obj=self.model,
                        id='fields.E180',
                    )
                )
        return errors

    def deconstruct(self):
        """Add ``encoder``/``decoder`` to the kwargs only when they are set."""
        name, path, args, kwargs = super().deconstruct()
        if self.encoder is not None:
            kwargs['encoder'] = self.encoder
        if self.decoder is not None:
            kwargs['decoder'] = self.decoder
        return name, path, args, kwargs

    def from_db_value(self, value, expression, connection):
        """
        Decode a value loaded from the database.

        ``None`` is returned unchanged. With a native JSON column and no
        custom decoder, the value is returned as the driver produced it.
        Non-string values from key transforms are passed through. Anything
        else is parsed with ``json.loads``, and text that is not valid JSON is
        returned raw.
        """
        if value is None:
            return value
        if connection.features.has_native_json_field and self.decoder is None:
            return value
        # Some backends (SQLite's JSON_EXTRACT at least) return scalar key
        # values as native SQL types such as int or float. json.loads() would
        # raise TypeError on those, so hand them back unchanged.
        if isinstance(expression, KeyTransform) and not isinstance(value, str):
            return value
        try:
            return json.loads(value, cls=self.decoder)
        except json.JSONDecodeError:
            return value

    def get_internal_type(self):
        return 'JSONField'

    def get_prep_value(self, value):
        """Serialize ``value`` to a JSON string; ``None`` is kept as ``None``."""
        if value is None:
            return value
        return json.dumps(value, cls=self.encoder)

    def get_transform(self, name):
        """
        Return a registered transform for ``name``. Otherwise return a
        factory that treats ``name`` as a JSON key (``field__<key>``).
        """
        transform = super().get_transform(name)
        if transform:
            return transform
        return KeyTransformFactory(name)

    def select_format(self, compiler, sql, params):
        """
        With a native JSON column and a custom decoder, select the value as
        text so that ``from_db_value()`` can decode it with that decoder.
        """
        if (
            compiler.connection.features.has_native_json_field and
            self.decoder is not None
        ):
            return compiler.connection.ops.json_cast_text_sql(sql), params
        return super().select_format(compiler, sql, params)

    def validate(self, value, model_instance):
        """
        Run the base validation, then ensure ``value`` is JSON-serializable.

        :raises exceptions.ValidationError: (code ``'invalid'``) if
            ``json.dumps`` rejects the value. That covers unsupported types
            (TypeError) and circular references or other encoder-reported
            errors (ValueError).
        """
        super().validate(value, model_instance)
        try:
            json.dumps(value, cls=self.encoder)
        except (TypeError, ValueError):
            raise exceptions.ValidationError(
                self.error_messages['invalid'],
                code='invalid',
                params={'value': value},
            )

    def value_to_string(self, obj):
        """Return the raw Python value (not a string) for serialization."""
        return self.value_from_object(obj)

    def formfield(self, **kwargs):
        """Use ``forms.JSONField``, forwarding encoder/decoder; kwargs win."""
        return super().formfield(**{
            'form_class': forms.JSONField,
            'encoder': self.encoder,
            'decoder': self.decoder,
            **kwargs,
        })
def compile_json_path(key_transforms, include_root=True):
    """
    Build a JSON path expression from a sequence of keys.

    Integer-like keys become array indexes (``[n]``). All other keys become
    member accesses whose name is quoted with ``json.dumps`` (``."name"``).
    The path starts with the ``$`` root unless ``include_root`` is false.
    """
    segments = ['$'] if include_root else []
    for key in key_transforms:
        try:
            index = int(key)
        except ValueError:  # not an integer: object member access
            segments.append('.' + json.dumps(key))
            continue
        segments.append('[%s]' % index)
    return ''.join(segments)
class DataContains(PostgresOperatorLookup):
    """
    ``contains`` lookup for JSONField.

    PostgreSQL uses the ``@>`` operator through PostgresOperatorLookup. Other
    backends reporting ``supports_json_field_contains`` compile to
    ``JSON_CONTAINS(lhs, rhs)``; the rest raise NotSupportedError.
    """
    postgres_operator = '@>'
    lookup_name = 'contains'

    def as_sql(self, compiler, connection):
        if connection.features.supports_json_field_contains:
            lhs_sql, lhs_args = self.process_lhs(compiler, connection)
            rhs_sql, rhs_args = self.process_rhs(compiler, connection)
            return (
                'JSON_CONTAINS(%s, %s)' % (lhs_sql, rhs_sql),
                (*lhs_args, *rhs_args),
            )
        raise NotSupportedError(
            'contains lookup is not supported on this database backend.'
        )
class ContainedBy(PostgresOperatorLookup):
    """
    ``contained_by`` lookup for JSONField: the inverse of ``contains``.

    PostgreSQL uses the ``<@`` operator. Other supporting backends compile to
    ``JSON_CONTAINS(rhs, lhs)``, i.e. with the operands (and params) swapped.
    """
    postgres_operator = '<@'
    lookup_name = 'contained_by'

    def as_sql(self, compiler, connection):
        if connection.features.supports_json_field_contains:
            lhs_sql, lhs_args = self.process_lhs(compiler, connection)
            rhs_sql, rhs_args = self.process_rhs(compiler, connection)
            return (
                'JSON_CONTAINS(%s, %s)' % (rhs_sql, lhs_sql),
                (*rhs_args, *lhs_args),
            )
        raise NotSupportedError(
            'contained_by lookup is not supported on this database backend.'
        )
class HasKeyLookup(PostgresOperatorLookup):
    """
    Base class for lookups testing whether a JSON document contains keys.

    PostgreSQL compiles through ``postgres_operator`` (PostgresOperatorLookup).
    The other backends call ``as_sql()`` with a vendor-specific ``template``.
    The template holds one ``%s`` for the compiled lhs and one escaped
    ``%%s`` for a key's JSON path. The resulting condition is repeated once
    per key and joined with ``logical_operator``.
    """
    # SQL joiner for multi-key lookups (set by subclasses); None means a
    # single condition is emitted.
    logical_operator = None

    def as_sql(self, compiler, connection, template=None):
        """
        Return ``(sql, params)`` for the key-existence check.

        ``template`` must be supplied; the vendor methods below do so.
        For a single condition the params are the lhs params followed by the
        key paths. When the condition is repeated for several keys, each copy
        embeds the lhs SQL again, so the lhs params are repeated in front of
        each key's path to match the placeholder order.
        """
        # Process JSON path from the left-hand side.
        if isinstance(self.lhs, KeyTransform):
            lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(compiler, connection)
            lhs_json_path = compile_json_path(lhs_key_transforms)
        else:
            lhs, lhs_params = self.process_lhs(compiler, connection)
            lhs_json_path = '$'
        sql = template % lhs
        # Process JSON path from the right-hand side.
        rhs = self.rhs
        rhs_params = []
        if not isinstance(rhs, (list, tuple)):
            rhs = [rhs]
        for key in rhs:
            if isinstance(key, KeyTransform):
                *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
            else:
                rhs_key_transforms = [key]
            # NOTE(review): compile_json_path() turns an integer-like key
            # such as '1' into an array index ("[1]"), not an object member.
            rhs_params.append('%s%s' % (
                lhs_json_path,
                compile_json_path(rhs_key_transforms, include_root=False),
            ))
        # Add condition for each key.
        if self.logical_operator:
            sql = '(%s)' % self.logical_operator.join([sql] * len(rhs_params))
            params = []
            for rhs_param in rhs_params:
                params.extend(lhs_params)
                params.append(rhs_param)
            return sql, tuple(params)
        return sql, tuple(lhs_params) + tuple(rhs_params)

    def as_mysql(self, compiler, connection):
        return self.as_sql(compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)")

    def as_oracle(self, compiler, connection):
        # Quote the path with Oracle's alternative-quoting syntax, using
        # U+FFFF as the delimiter. A quote character inside a key can then no
        # longer terminate the literal. json.dumps() ASCII-escapes U+FFFF in
        # keys by default, so the delimiter cannot occur inside a path.
        sql, params = self.as_sql(
            compiler, connection, template="JSON_EXISTS(%s, q'\uffff%%s\uffff')",
        )
        # Add paths directly into SQL because path expressions cannot be passed
        # as bind variables on Oracle.
        # NOTE(review): this also inlines any lhs params as raw SQL text; only
        # safe for lhs expressions without untrusted params.
        return sql % tuple(params), []

    def as_postgresql(self, compiler, connection):
        # For a KeyTransform rhs, move all but its last key onto the lhs as
        # nested KeyTransforms and test the last key with the operator.
        # Note: this mutates self.lhs and self.rhs.
        if isinstance(self.rhs, KeyTransform):
            *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
            for key in rhs_key_transforms[:-1]:
                self.lhs = KeyTransform(key, self.lhs)
            self.rhs = rhs_key_transforms[-1]
        return super().as_postgresql(compiler, connection)

    def as_sqlite(self, compiler, connection):
        # JSON_TYPE() is NULL only when the path does not exist.
        return self.as_sql(compiler, connection, template='JSON_TYPE(%s, %%s) IS NOT NULL')
class HasKey(HasKeyLookup):
    """``has_key`` lookup: the JSON document contains the given key."""
    postgres_operator = '?'
    # Leave the rhs key un-prepared (presumably so the field's
    # get_prep_value() does not JSON-encode it — confirm).
    prepare_rhs = False
    lookup_name = 'has_key'
class HasKeys(HasKeyLookup):
    """``has_keys`` lookup: every key in the given sequence must be present."""
    postgres_operator = '?&'
    logical_operator = ' AND '
    lookup_name = 'has_keys'

    def get_prep_lookup(self):
        # Keys are matched as strings, so coerce each item (e.g. ints).
        return list(map(str, self.rhs))
class HasAnyKeys(HasKeys):
    """``has_any_keys`` lookup: at least one of the given keys is present."""
    postgres_operator = '?|'
    logical_operator = ' OR '
    lookup_name = 'has_any_keys'
class JSONExact(lookups.Exact):
    """
    ``exact`` lookup comparing a whole JSONField value.

    A ``None`` rhs is allowed; ``process_rhs()`` replaces it with the JSON
    literal ``'null'``.
    """
    # Permit None as the rhs value (presumably checked by the base lookup —
    # confirm); process_rhs() maps it to 'null'.
    can_use_none_as_rhs = True

    def process_lhs(self, compiler, connection):
        """Compile the lhs; on SQLite, wrap it in JSON_TYPE when matching None."""
        lhs, lhs_params = super().process_lhs(compiler, connection)
        if connection.vendor == 'sqlite':
            # super() here resolves to Exact, so a None rhs still shows up as
            # [None]; the 'null' conversion happens only in process_rhs below.
            rhs, rhs_params = super().process_rhs(compiler, connection)
            if rhs == '%s' and rhs_params == [None]:
                # Use JSON_TYPE instead of JSON_EXTRACT for NULLs.
                lhs = "JSON_TYPE(%s, '$')" % lhs
        return lhs, lhs_params

    def process_rhs(self, compiler, connection):
        """
        Compile the rhs. A None value becomes the JSON literal 'null'. On
        MySQL each placeholder is wrapped in JSON_EXTRACT(..., '$'),
        presumably so the bound text is compared as a JSON value.
        """
        rhs, rhs_params = super().process_rhs(compiler, connection)
        # Treat None lookup values as null.
        if rhs == '%s' and rhs_params == [None]:
            rhs_params = ['null']
        if connection.vendor == 'mysql':
            # Assumes rhs contains exactly one '%s' per param — TODO confirm
            # for expression rhs values.
            func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
            rhs = rhs % tuple(func)
        return rhs, rhs_params
# Lookups available directly on a JSONField (``field__<lookup>``), registered
# in this order.
for _json_lookup in (
    DataContains,
    ContainedBy,
    HasKey,
    HasKeys,
    HasAnyKeys,
    JSONExact,
):
    JSONField.register_lookup(_json_lookup)
del _json_lookup
class KeyTransform(Transform):
    """
    Transform extracting the value stored under ``key_name`` in a JSON value.

    Chains such as ``field__a__b`` are nested KeyTransforms. They are
    flattened into a single JSON path by ``preprocess_lhs()``.
    """
    postgres_operator = '->'
    postgres_nested_operator = '#>'

    def __init__(self, key_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keys are always handled as strings (e.g. 0 becomes '0').
        self.key_name = str(key_name)

    def preprocess_lhs(self, compiler, connection, lhs_only=False):
        """
        Flatten the chain of nested KeyTransforms that ends at this one.

        Compiles the innermost non-KeyTransform expression. Returns
        ``(lhs_sql, params, key_transforms)``, where ``key_transforms`` lists
        the key names from the document root down to ``self.key_name``. When
        ``lhs_only`` is true, returns only ``(lhs_sql, params)``.
        """
        key_transforms = [self.key_name]
        previous = self.lhs
        while isinstance(previous, KeyTransform):
            key_transforms.append(previous.key_name)
            previous = previous.lhs
        # Collected from this transform back toward the field; reverse so the
        # path reads from the document root downward.
        key_transforms.reverse()
        lhs, params = compiler.compile(previous)
        if lhs_only:
            return lhs, params
        if connection.vendor == 'oracle':
            # Escape string-formatting: on Oracle the path is inlined into the
            # SQL text (see as_oracle below).
            key_transforms = [key.replace('%', '%%') for key in key_transforms]
        return lhs, params, key_transforms

    def as_mysql(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,)

    def as_oracle(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        # The path cannot be a bind variable here, so it is inlined. Oracle's
        # alternative-quoting syntax with a U+FFFF delimiter stops a quote in
        # a key from ending the literal. json.dumps() ASCII-escapes U+FFFF in
        # keys, so the delimiter cannot appear inside the path.
        return (
            "COALESCE(JSON_QUERY(%s, q'\uffff%s\uffff'), "
            "JSON_VALUE(%s, q'\uffff%s\uffff'))" %
            ((lhs, json_path) * 2)
        ), tuple(params) * 2

    def as_postgresql(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        if len(key_transforms) > 1:
            # Nested path: pass the whole key list as one parameter to #>.
            # tuple() keeps this working when params is a tuple.
            sql = '(%s %s %%s)' % (lhs, self.postgres_nested_operator)
            return sql, tuple(params) + (key_transforms,)
        try:
            lookup = int(self.key_name)
        except ValueError:
            lookup = self.key_name
        return '(%s %s %%s)' % (lhs, self.postgres_operator), tuple(params) + (lookup,)

    def as_sqlite(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,)
class KeyTextTransform(KeyTransform):
    """
    KeyTransform whose PostgreSQL form extracts the key's value as text,
    using ``->>`` / ``#>>`` instead of ``->`` / ``#>``.
    """
    postgres_nested_operator = '#>>'
    postgres_operator = '->>'
class KeyTransformTextLookupMixin:
    """
    Mixin for lookups that need a text lhs from a JSONField key lookup.

    The incoming KeyTransform is swapped for an equivalent KeyTextTransform.
    On PostgreSQL this uses the ->> operator rather than casting the key's
    value to text and matching against that representation.
    """
    def __init__(self, key_transform, *args, **kwargs):
        if isinstance(key_transform, KeyTransform):
            as_text = KeyTextTransform(
                key_transform.key_name,
                *key_transform.source_expressions,
                **key_transform.extra,
            )
            super().__init__(as_text, *args, **kwargs)
            return
        raise TypeError(
            'Transform should be an instance of KeyTransform in order to '
            'use this lookup.'
        )
class CaseInsensitiveMixin:
    """
    Mixin enabling case-insensitive comparison of JSON values on MySQL.

    MySQL treats strings in a JSON context with the binary utf8mb4_bin
    collation, so JSON comparisons are case-sensitive. Both sides are
    therefore wrapped in LOWER() on MySQL and left untouched elsewhere.
    """
    def process_lhs(self, compiler, connection):
        sql, params = super().process_lhs(compiler, connection)
        if connection.vendor != 'mysql':
            return sql, params
        return 'LOWER(%s)' % sql, params

    def process_rhs(self, compiler, connection):
        sql, params = super().process_rhs(compiler, connection)
        if connection.vendor != 'mysql':
            return sql, params
        return 'LOWER(%s)' % sql, params
class KeyTransformIsNull(lookups.IsNull):
    """
    ``isnull`` lookup on a JSON key.

    ``key__isnull=False`` is the same as ``has_key='key'``.
    ``key__isnull=True`` matches rows where the key is absent. On SQLite and
    Oracle, a key that is present but holds a JSON ``null`` must not match.
    """
    def as_oracle(self, compiler, connection):
        sql, params = HasKey(
            self.lhs.lhs,
            self.lhs.key_name,
        ).as_oracle(compiler, connection)
        if not self.rhs:
            return sql, params
        # Column doesn't have the key, or the column itself IS NULL.
        lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
        return '(NOT %s OR %s IS NULL)' % (sql, lhs), tuple(params) + tuple(lhs_params)

    def as_sqlite(self, compiler, connection):
        # JSON_TYPE() returns SQL NULL only for a missing path; a JSON null
        # yields the text 'null'. JSON_EXTRACT() returns NULL in both cases.
        template = 'JSON_TYPE(%s, %%s) IS NULL'
        if not self.rhs:
            template = 'JSON_TYPE(%s, %%s) IS NOT NULL'
        return HasKey(self.lhs.lhs, self.lhs.key_name).as_sql(
            compiler,
            connection,
            template=template,
        )
class KeyTransformExact(JSONExact):
    """``exact`` lookup on the value extracted by a KeyTransform."""

    def process_lhs(self, compiler, connection):
        lhs, lhs_params = super().process_lhs(compiler, connection)
        if connection.vendor == 'sqlite':
            # JSONExact.process_rhs has already turned a None rhs into 'null'.
            rhs, rhs_params = super().process_rhs(compiler, connection)
            if rhs == '%s' and rhs_params == ['null']:
                # Compare the key's JSON type instead of its extracted value.
                # The JSON path is presumably still the last entry of
                # lhs_params (KeyTransform.as_sqlite appends it) — confirm.
                lhs, _ = self.lhs.preprocess_lhs(compiler, connection, lhs_only=True)
                lhs = 'JSON_TYPE(%s, %%s)' % lhs
        return lhs, lhs_params

    def process_rhs(self, compiler, connection):
        if isinstance(self.rhs, KeyTransform):
            # Comparing two keys: skip the JSON value handling of JSONExact
            # and Exact and compile the rhs expression as-is.
            return super(lookups.Exact, self).process_rhs(compiler, connection)
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.vendor == 'oracle':
            # Wrap each JSON-encoded value in a one-member JSON object and
            # extract it again. The value stays a bind parameter rather than
            # being inlined into a SQL string literal, so quotes in the data
            # cannot break out of the statement.
            func = []
            sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
            for value in rhs_params:
                value = json.loads(value)
                if isinstance(value, (list, dict)):
                    func.append(sql % 'JSON_QUERY')
                else:
                    func.append(sql % 'JSON_VALUE')
            rhs = rhs % tuple(func)
        elif connection.vendor == 'sqlite':
            func = ["JSON_EXTRACT(%s, '$')" if value != 'null' else '%s' for value in rhs_params]
            rhs = rhs % tuple(func)
        return rhs, rhs_params

    def as_oracle(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if rhs_params == ['null']:
            # Field has key and it's NULL.
            has_key_expr = HasKey(self.lhs.lhs, self.lhs.key_name)
            has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
            is_null_expr = self.lhs.get_lookup('isnull')(self.lhs, True)
            is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
            return (
                '%s AND %s' % (has_key_sql, is_null_sql),
                tuple(has_key_params) + tuple(is_null_params),
            )
        return super().as_sql(compiler, connection)
class KeyTransformIExact(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact):
    """Case-insensitive ``iexact`` on a key's text value (LOWER() on MySQL)."""
class KeyTransformIContains(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains):
    """Case-insensitive ``icontains`` on a key's text value (LOWER() on MySQL)."""
class KeyTransformContains(KeyTransformTextLookupMixin, lookups.Contains):
    """``contains`` substring match on a key's text value."""
class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    """``startswith`` on a key's text value."""
class KeyTransformIStartsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith):
    """Case-insensitive ``istartswith`` on a key's text value (LOWER() on MySQL)."""
class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    """``endswith`` on a key's text value."""
class KeyTransformIEndsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith):
    """Case-insensitive ``iendswith`` on a key's text value (LOWER() on MySQL)."""
class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    """``regex`` match on a key's text value."""
class KeyTransformIRegex(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex):
    """Case-insensitive ``iregex`` on a key's text value (LOWER() on MySQL)."""
class KeyTransformNumericLookupMixin:
    """
    Mixin for ordering comparisons (lt/lte/gt/gte) on a JSON key.

    Without a native JSON column, each rhs param is decoded with
    ``json.loads`` so that a plain Python value is bound instead of JSON
    text. The params presumably arrive JSON-encoded by the field's
    get_prep_value() — confirm.
    """
    def process_rhs(self, compiler, connection):
        sql, params = super().process_rhs(compiler, connection)
        if connection.features.has_native_json_field:
            return sql, params
        return sql, [json.loads(param) for param in params]
class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
    """``lt`` comparison on a JSON key's value."""
class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
    """``lte`` comparison on a JSON key's value."""
class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
    """``gt`` comparison on a JSON key's value."""
class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
    """``gte`` comparison on a JSON key's value."""
# Lookups available on JSON key transforms (``field__key__<lookup>``),
# registered in this order.
for _key_lookup in (
    KeyTransformExact,
    KeyTransformIExact,
    KeyTransformIsNull,
    KeyTransformContains,
    KeyTransformIContains,
    KeyTransformStartsWith,
    KeyTransformIStartsWith,
    KeyTransformEndsWith,
    KeyTransformIEndsWith,
    KeyTransformRegex,
    KeyTransformIRegex,
    KeyTransformLt,
    KeyTransformLte,
    KeyTransformGt,
    KeyTransformGte,
):
    KeyTransform.register_lookup(_key_lookup)
del _key_lookup
class KeyTransformFactory:
    """
    Callable that builds a KeyTransform for a fixed key name.

    JSONField.get_transform() returns one of these for any name without a
    registered transform, so ``field__<key>`` becomes a key lookup.
    """

    def __init__(self, key_name):
        self.key_name = key_name

    def __call__(self, *args, **kwargs):
        key = self.key_name
        return KeyTransform(key, *args, **kwargs)
|