123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600 |
- import json
- from django import forms
- from django.core import checks, exceptions
- from django.db import NotSupportedError, connections, router
- from django.db.models import lookups
- from django.db.models.constants import LOOKUP_SEP
- from django.db.models.fields import TextField
- from django.db.models.lookups import PostgresOperatorLookup, Transform
- from django.utils.translation import gettext_lazy as _
- from . import Field
- from .mixins import CheckFieldDefaultMixin
- __all__ = ["JSONField"]
class JSONField(CheckFieldDefaultMixin, Field):
    """
    Model field whose Python value is stored serialized as JSON.

    Values are serialized with ``json.dumps(value, cls=encoder)`` before being
    sent to the database and parsed with ``json.loads(value, cls=decoder)``
    when read back.
    """

    empty_strings_allowed = False
    description = _("A JSON object")
    default_error_messages = {
        "invalid": _("Value must be valid JSON."),
    }
    _default_hint = ("dict", "{}")

    def __init__(
        self,
        verbose_name=None,
        name=None,
        encoder=None,
        decoder=None,
        **kwargs,
    ):
        """
        Remember the optional ``encoder`` and ``decoder`` on the field.

        Raise ValueError if either one is truthy but not callable.
        """
        if encoder and not callable(encoder):
            raise ValueError("The encoder parameter must be a callable object.")
        if decoder and not callable(decoder):
            raise ValueError("The decoder parameter must be a callable object.")
        self.encoder, self.decoder = encoder, decoder
        super().__init__(verbose_name, name, **kwargs)

    def check(self, **kwargs):
        """Run the inherited checks plus the per-database JSON support check."""
        problems = super().check(**kwargs)
        problems.extend(self._check_supported(kwargs.get("databases") or []))
        return problems

    def _check_supported(self, databases):
        """
        Return a ``fields.E180`` error for every alias in ``databases`` whose
        backend doesn't support JSON fields.
        """
        problems = []
        for alias in databases:
            if not router.allow_migrate_model(alias, self.model):
                continue
            connection = connections[alias]
            opts = self.model._meta
            required_vendor = opts.required_db_vendor
            if required_vendor and required_vendor != connection.vendor:
                # The model requires a different database vendor.
                continue
            if (
                "supports_json_field" not in opts.required_db_features
                and not connection.features.supports_json_field
            ):
                problems.append(
                    checks.Error(
                        "%s does not support JSONFields." % connection.display_name,
                        obj=self.model,
                        id="fields.E180",
                    )
                )
        return problems

    def deconstruct(self):
        """Add ``encoder``/``decoder`` to the deconstructed kwargs when set."""
        name, path, args, kwargs = super().deconstruct()
        for option in ("encoder", "decoder"):
            codec = getattr(self, option)
            if codec is not None:
                kwargs[option] = codec
        return name, path, args, kwargs

    def from_db_value(self, value, expression, connection):
        """
        Decode a value read from the database.

        ``None`` is returned unchanged, as is any value that fails to parse
        as JSON.
        """
        if value is None:
            return None
        # Some backends (SQLite at least) extract non-string values in their
        # SQL datatypes.
        extracted_natively = isinstance(expression, KeyTransform) and not isinstance(
            value, str
        )
        if extracted_natively:
            return value
        try:
            decoded = json.loads(value, cls=self.decoder)
        except json.JSONDecodeError:
            return value
        return decoded

    def get_internal_type(self):
        """Return the internal type name of this field."""
        return "JSONField"

    def get_prep_value(self, value):
        """Serialize ``value`` to a JSON string; ``None`` is passed through."""
        return None if value is None else json.dumps(value, cls=self.encoder)

    def get_transform(self, name):
        """
        Return the registered transform called ``name`` or, when there is
        none, a factory producing a key transform for ``name``.
        """
        return super().get_transform(name) or KeyTransformFactory(name)

    def validate(self, value, model_instance):
        """
        Run the inherited validation, then make sure ``value`` can be
        serialized to JSON.

        Raise ValidationError (code ``invalid``) when serialization raises
        TypeError.
        """
        super().validate(value, model_instance)
        try:
            json.dumps(value, cls=self.encoder)
        except TypeError:
            message = self.error_messages["invalid"]
            raise exceptions.ValidationError(
                message,
                code="invalid",
                params={"value": value},
            )

    def value_to_string(self, obj):
        """Return this field's value taken from ``obj``, not converted further."""
        return self.value_from_object(obj)

    def formfield(self, **kwargs):
        """
        Build the form field, defaulting to ``forms.JSONField`` with this
        field's encoder and decoder; ``kwargs`` override the defaults.
        """
        options = {
            "form_class": forms.JSONField,
            "encoder": self.encoder,
            "decoder": self.decoder,
        }
        options.update(kwargs)
        return super().formfield(**options)
def compile_json_path(key_transforms, include_root=True):
    """
    Build a JSON path string from a sequence of keys.

    Keys accepted by ``int()`` become array indexes (``[N]``); every other
    key becomes a JSON-quoted member name (``."key"``). The path starts with
    ``$`` unless ``include_root`` is false.
    """
    segments = []
    if include_root:
        segments.append("$")
    for key in key_transforms:
        try:
            index = int(key)
        except ValueError:  # non-integer
            segments.extend((".", json.dumps(key)))
            continue
        segments.append("[%s]" % index)
    return "".join(segments)
class DataContains(PostgresOperatorLookup):
    """
    ``contains`` lookup for JSON data.

    The generic SQL is ``JSON_CONTAINS(lhs, rhs)``; ``postgres_operator`` is
    the operator made available to PostgresOperatorLookup.
    """

    lookup_name = "contains"
    postgres_operator = "@>"

    def as_sql(self, compiler, connection):
        """
        Return the ``JSON_CONTAINS`` SQL and its parameters.

        Raise NotSupportedError when the backend reports no support for JSON
        containment.
        """
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contains lookup is not supported on this database backend."
            )
        container_sql, container_params = self.process_lhs(compiler, connection)
        subset_sql, subset_params = self.process_rhs(compiler, connection)
        sql = "JSON_CONTAINS(%s, %s)" % (container_sql, subset_sql)
        return sql, (*container_params, *subset_params)
class ContainedBy(PostgresOperatorLookup):
    """
    ``contained_by`` lookup for JSON data.

    The generic SQL is ``JSON_CONTAINS(rhs, lhs)``, i.e. the operands are
    swapped compared to the ``contains`` lookup.
    """

    lookup_name = "contained_by"
    postgres_operator = "<@"

    def as_sql(self, compiler, connection):
        """
        Return the ``JSON_CONTAINS`` SQL and its parameters.

        Raise NotSupportedError when the backend reports no support for JSON
        containment.
        """
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contained_by lookup is not supported on this database backend."
            )
        subset_sql, subset_params = self.process_lhs(compiler, connection)
        container_sql, container_params = self.process_rhs(compiler, connection)
        sql = "JSON_CONTAINS(%s, %s)" % (container_sql, subset_sql)
        # Parameters follow the order of the placeholders: rhs first.
        return sql, (*container_params, *subset_params)
class HasKeyLookup(PostgresOperatorLookup):
    """
    Base for lookups testing whether JSON data contains one or more keys.

    Subclasses set ``logical_operator`` to the SQL fragment used to join the
    per-key conditions when several keys are given.
    """

    logical_operator = None

    def compile_json_path_final_key(self, key_transform):
        # Compile the final key without interpreting ints as array elements.
        return "." + json.dumps(key_transform)

    def as_sql(self, compiler, connection, template=None):
        """
        Return the SQL condition and parameters for the key test.

        ``template`` is a %-format string taking the compiled left-hand side;
        each right-hand side key contributes one full JSON path parameter.
        """
        # Process JSON path from the left-hand side.
        if isinstance(self.lhs, KeyTransform):
            lhs_sql, lhs_params, lhs_keys = self.lhs.preprocess_lhs(
                compiler, connection
            )
            path_prefix = compile_json_path(lhs_keys)
        else:
            lhs_sql, lhs_params = self.process_lhs(compiler, connection)
            path_prefix = "$"
        condition = template % lhs_sql
        # Process JSON path from the right-hand side.
        keys = self.rhs if isinstance(self.rhs, (list, tuple)) else [self.rhs]
        json_paths = []
        for key in keys:
            if isinstance(key, KeyTransform):
                key_path = key.preprocess_lhs(compiler, connection)[-1]
            else:
                key_path = [key]
            *parent_keys, final_key = key_path
            json_paths.append(
                path_prefix
                + compile_json_path(parent_keys, include_root=False)
                + self.compile_json_path_final_key(final_key)
            )
        # Add condition for each key.
        if self.logical_operator:
            repeated = [condition] * len(json_paths)
            condition = "(%s)" % self.logical_operator.join(repeated)
        return condition, (*lhs_params, *json_paths)

    def as_mysql(self, compiler, connection):
        """Return the key test expressed with ``JSON_CONTAINS_PATH``."""
        mysql_template = "JSON_CONTAINS_PATH(%s, 'one', %%s)"
        return self.as_sql(compiler, connection, template=mysql_template)

    def as_oracle(self, compiler, connection):
        """Return the key test expressed with ``JSON_EXISTS``."""
        oracle_template = "JSON_EXISTS(%s, '%%s')"
        sql, params = self.as_sql(compiler, connection, template=oracle_template)
        # Add paths directly into SQL because path expressions cannot be passed
        # as bind variables on Oracle.
        return sql % tuple(params), []

    def as_postgresql(self, compiler, connection):
        """
        Return the PostgreSQL SQL for the key test.

        When the right-hand side is a KeyTransform, its leading keys are
        folded into the left-hand side and only its final key is kept as the
        right-hand side.
        """
        if isinstance(self.rhs, KeyTransform):
            key_path = self.rhs.preprocess_lhs(compiler, connection)[-1]
            for parent_key in key_path[:-1]:
                self.lhs = KeyTransform(parent_key, self.lhs)
            self.rhs = key_path[-1]
        return super().as_postgresql(compiler, connection)

    def as_sqlite(self, compiler, connection):
        """Return the key test expressed with ``JSON_TYPE(...) IS NOT NULL``."""
        sqlite_template = "JSON_TYPE(%s, %%s) IS NOT NULL"
        return self.as_sql(compiler, connection, template=sqlite_template)
class HasKey(HasKeyLookup):
    """``has_key`` lookup: test for a single key."""

    lookup_name = "has_key"
    postgres_operator = "?"
    prepare_rhs = False
class HasKeys(HasKeyLookup):
    """``has_keys`` lookup: per-key conditions are joined with ``AND``."""

    lookup_name = "has_keys"
    postgres_operator = "?&"
    logical_operator = " AND "

    def get_prep_lookup(self):
        """Return the right-hand side keys, each converted with ``str()``."""
        return list(map(str, self.rhs))
class HasAnyKeys(HasKeys):
    """``has_any_keys`` lookup: per-key conditions are joined with ``OR``."""

    lookup_name = "has_any_keys"
    postgres_operator = "?|"
    logical_operator = " OR "
class HasKeyOrArrayIndex(HasKey):
    """
    ``HasKey`` variant whose final key goes through ``compile_json_path()``,
    so an integer-like final key is compiled as an array index.
    """

    def compile_json_path_final_key(self, key_transform):
        final_path = compile_json_path([key_transform], include_root=False)
        return final_path
class CaseInsensitiveMixin:
    """
    Mixin to allow case-insensitive comparison of JSON values on MySQL.
    MySQL handles strings used in JSON context using the utf8mb4_bin collation.
    Because utf8mb4_bin is a binary collation, comparison of JSON values is
    case-sensitive.
    """

    def process_lhs(self, compiler, connection):
        """Wrap the left-hand side SQL in ``LOWER()`` when the vendor is MySQL."""
        sql, params = super().process_lhs(compiler, connection)
        if connection.vendor != "mysql":
            return sql, params
        return "LOWER(%s)" % sql, params

    def process_rhs(self, compiler, connection):
        """Wrap the right-hand side SQL in ``LOWER()`` when the vendor is MySQL."""
        sql, params = super().process_rhs(compiler, connection)
        if connection.vendor != "mysql":
            return sql, params
        return "LOWER(%s)" % sql, params
class JSONExact(lookups.Exact):
    """``exact`` lookup for JSON data; ``None`` is accepted as right-hand side."""

    can_use_none_as_rhs = True

    def process_rhs(self, compiler, connection):
        """
        Return the right-hand side SQL and parameters.

        A lone ``None`` parameter is replaced by the JSON literal ``null``; on
        MySQL every placeholder is wrapped in ``JSON_EXTRACT(%s, '$')``.
        """
        sql, params = super().process_rhs(compiler, connection)
        # Treat None lookup values as null.
        if sql == "%s" and params == [None]:
            params = ["null"]
        if connection.vendor == "mysql":
            wrappers = ("JSON_EXTRACT(%s, '$')",) * len(params)
            sql = sql % wrappers
        return sql, params
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
    """``icontains`` lookup combined with the MySQL case-insensitivity mixin."""
# Register the JSON-specific lookups on JSONField, in declaration order.
for _json_lookup in (
    DataContains,
    ContainedBy,
    HasKey,
    HasKeys,
    HasAnyKeys,
    JSONExact,
    JSONIContains,
):
    JSONField.register_lookup(_json_lookup)
# Don't leave the loop variable behind in the module namespace.
del _json_lookup
class KeyTransform(Transform):
    """
    Transform extracting the value stored under a key (or index) of JSON data.

    Chained KeyTransforms are collapsed into a single JSON path when compiled.
    """

    postgres_operator = "->"
    postgres_nested_operator = "#>"

    def __init__(self, key_name, *args, **kwargs):
        """Store ``key_name`` as a string; other arguments go to the parent."""
        super().__init__(*args, **kwargs)
        self.key_name = str(key_name)

    def preprocess_lhs(self, compiler, connection):
        """
        Return ``(lhs_sql, params, keys)`` for this chain of key transforms.

        ``keys`` lists the key names from the outermost JSON value down to
        this transform; ``lhs_sql`` and ``params`` come from compiling the
        first expression in the chain that isn't a KeyTransform.
        """
        keys = [self.key_name]
        node = self.lhs
        while isinstance(node, KeyTransform):
            keys.append(node.key_name)
            node = node.lhs
        # Keys were collected innermost-first; present them root-first.
        keys.reverse()
        lhs_sql, params = compiler.compile(node)
        if connection.vendor == "oracle":
            # Escape string-formatting.
            keys = [key.replace("%", "%%") for key in keys]
        return lhs_sql, params, keys

    def as_mysql(self, compiler, connection):
        """Return the extraction expressed with ``JSON_EXTRACT``."""
        lhs_sql, params, keys = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(keys)
        return "JSON_EXTRACT(%s, %%s)" % lhs_sql, (*params, json_path)

    def as_oracle(self, compiler, connection):
        """
        Return the extraction as ``COALESCE(JSON_QUERY(...), JSON_VALUE(...))``
        with the JSON path written directly into the SQL.
        """
        lhs_sql, params, keys = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(keys)
        sql = "COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))" % (
            lhs_sql,
            json_path,
            lhs_sql,
            json_path,
        )
        return sql, tuple(params) * 2

    def as_postgresql(self, compiler, connection):
        """
        Return the extraction using the nested operator with the whole key
        list when several keys are chained, else the single-key operator.
        """
        lhs_sql, params, keys = self.preprocess_lhs(compiler, connection)
        if len(keys) > 1:
            operator, operand = self.postgres_nested_operator, keys
        else:
            operator = self.postgres_operator
            # An integer-like key is passed as an int, anything else as is.
            try:
                operand = int(self.key_name)
            except ValueError:
                operand = self.key_name
        return "(%s %s %%s)" % (lhs_sql, operator), (*params, operand)

    def as_sqlite(self, compiler, connection):
        """
        Return a ``CASE`` expression yielding ``JSON_TYPE`` when it is one of
        the backend's ``jsonfield_datatype_values`` and ``JSON_EXTRACT``
        otherwise.
        """
        lhs_sql, params, keys = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(keys)
        datatype_values = ",".join(map(repr, connection.ops.jsonfield_datatype_values))
        sql = (
            "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
            "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
        ) % (lhs_sql, datatype_values, lhs_sql, lhs_sql)
        # The lhs parameters and the path are needed once per occurrence.
        return sql, (*params, json_path) * 3
class KeyTextTransform(KeyTransform):
    """KeyTransform variant whose output field is a ``TextField``."""

    postgres_operator = "->>"
    postgres_nested_operator = "#>>"
    output_field = TextField()

    def as_mysql(self, compiler, connection):
        """
        Return the text extraction: ``JSON_UNQUOTE`` around the parent's SQL
        on MariaDB, the ``->>`` operator otherwise.
        """
        if not connection.mysql_is_mariadb:
            lhs_sql, params, keys = self.preprocess_lhs(compiler, connection)
            json_path = compile_json_path(keys)
            return "(%s ->> %%s)" % lhs_sql, (*params, json_path)
        # MariaDB doesn't support -> and ->> operators (see MDEV-13594).
        extract_sql, params = super().as_mysql(compiler, connection)
        return "JSON_UNQUOTE(%s)" % extract_sql, params

    @classmethod
    def from_lookup(cls, lookup):
        """
        Build nested transforms from a ``LOOKUP_SEP``-separated string.

        The first component is the innermost expression; each following
        component wraps the result in another transform. Raise ValueError
        when ``lookup`` has no component after the first one.
        """
        expression, *keys = lookup.split(LOOKUP_SEP)
        if not keys:
            raise ValueError("Lookup must contain key or index transforms.")
        for key in keys:
            expression = cls(key, expression)
        return expression


# Short alias for building key text transforms from a lookup string.
KT = KeyTextTransform.from_lookup
class KeyTransformTextLookupMixin:
    """
    Mixin for combining with a lookup expecting a text lhs from a JSONField
    key lookup. On PostgreSQL, make use of the ->> operator instead of casting
    key values to text and performing the lookup on the resulting
    representation.
    """

    def __init__(self, key_transform, *args, **kwargs):
        """
        Replace ``key_transform`` with an equivalent KeyTextTransform before
        initializing the lookup.

        Raise TypeError if ``key_transform`` isn't a KeyTransform.
        """
        if not isinstance(key_transform, KeyTransform):
            raise TypeError(
                "Transform should be an instance of KeyTransform in order to "
                "use this lookup."
            )
        sources = key_transform.source_expressions
        text_transform = KeyTextTransform(
            key_transform.key_name, *sources, **key_transform.extra
        )
        super().__init__(text_transform, *args, **kwargs)
class KeyTransformIsNull(lookups.IsNull):
    """``isnull`` lookup applied to a JSON key transform."""

    # key__isnull=False is the same as has_key='key'
    def as_oracle(self, compiler, connection):
        """
        Return the key-existence test when the rhs is false, otherwise its
        negation combined with an ``IS NULL`` test on the column.
        """
        has_key = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
        has_key_sql, has_key_params = has_key.as_oracle(compiler, connection)
        if not self.rhs:
            return has_key_sql, has_key_params
        # Column doesn't have a key or IS NULL.
        column_sql, column_params = self.lhs.preprocess_lhs(compiler, connection)[:2]
        return (
            "(NOT %s OR %s IS NULL)" % (has_key_sql, column_sql),
            (*has_key_params, *column_params),
        )

    def as_sqlite(self, compiler, connection):
        """Return a ``JSON_TYPE(...) IS [NOT] NULL`` test for the key."""
        if self.rhs:
            template = "JSON_TYPE(%s, %%s) IS NULL"
        else:
            template = "JSON_TYPE(%s, %%s) IS NOT NULL"
        has_key = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
        return has_key.as_sql(compiler, connection, template=template)
class KeyTransformIn(lookups.In):
    """``in`` lookup applied to a JSON key transform."""

    def resolve_expression_parameter(self, compiler, connection, sql, param):
        """
        Return the SQL and parameters for one member of the ``IN`` list.

        Plain values (no ``as_sql`` attribute) on backends without a native
        JSON field get a vendor-specific wrapper; on MariaDB the result is
        additionally wrapped in ``JSON_UNQUOTE``.
        """
        sql, params = super().resolve_expression_parameter(
            compiler, connection, sql, param
        )
        vendor = connection.vendor
        is_plain_value = not (
            hasattr(param, "as_sql") or connection.features.has_native_json_field
        )
        if is_plain_value:
            if vendor == "oracle":
                decoded = json.loads(param)
                function = (
                    "JSON_QUERY" if isinstance(decoded, (list, dict)) else "JSON_VALUE"
                )
                sql = (
                    "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
                    % function
                )
            elif vendor == "mysql" or (
                vendor == "sqlite"
                and params[0] not in connection.ops.jsonfield_datatype_values
            ):
                sql = "JSON_EXTRACT(%s, '$')"
        if vendor == "mysql" and connection.mysql_is_mariadb:
            sql = "JSON_UNQUOTE(%s)" % sql
        return sql, params
class KeyTransformExact(JSONExact):
    """``exact`` lookup applied to a JSON key transform."""

    def process_rhs(self, compiler, connection):
        """
        Return the right-hand side SQL and parameters.

        A KeyTransform rhs skips the ``Exact``/``JSONExact`` processing. On
        Oracle and SQLite each placeholder is replaced by a wrapper chosen
        from the parameter's value.
        """
        if isinstance(self.rhs, KeyTransform):
            return super(lookups.Exact, self).process_rhs(compiler, connection)
        sql, params = super().process_rhs(compiler, connection)
        if connection.vendor == "oracle":
            wrapper = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
            placeholders = []
            for raw in params:
                decoded = json.loads(raw)
                if isinstance(decoded, (list, dict)):
                    function = "JSON_QUERY"
                else:
                    function = "JSON_VALUE"
                placeholders.append(wrapper % function)
            sql = sql % tuple(placeholders)
        elif connection.vendor == "sqlite":
            placeholders = [
                "%s"
                if raw in connection.ops.jsonfield_datatype_values
                else "JSON_EXTRACT(%s, '$')"
                for raw in params
            ]
            sql = sql % tuple(placeholders)
        return sql, params

    def as_oracle(self, compiler, connection):
        """
        Return the Oracle SQL; a JSON ``null`` rhs becomes "key exists AND
        value IS NULL", anything else uses the regular ``as_sql()``.
        """
        rhs_params = super().process_rhs(compiler, connection)[1]
        if rhs_params != ["null"]:
            return super().as_sql(compiler, connection)
        # Field has key and it's NULL.
        has_key = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
        has_key_sql, has_key_params = has_key.as_oracle(compiler, connection)
        is_null = self.lhs.get_lookup("isnull")(self.lhs, True)
        is_null_sql, is_null_params = is_null.as_sql(compiler, connection)
        return (
            "%s AND %s" % (has_key_sql, is_null_sql),
            (*has_key_params, *is_null_params),
        )
class KeyTransformIExact(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
):
    """``iexact`` lookup on the text value of a JSON key."""
class KeyTransformIContains(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
):
    """``icontains`` lookup on the text value of a JSON key."""
class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    """``startswith`` lookup on the text value of a JSON key."""
class KeyTransformIStartsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
):
    """``istartswith`` lookup on the text value of a JSON key."""
class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    """``endswith`` lookup on the text value of a JSON key."""
class KeyTransformIEndsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
):
    """``iendswith`` lookup on the text value of a JSON key."""
class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    """``regex`` lookup on the text value of a JSON key."""
class KeyTransformIRegex(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
):
    """``iregex`` lookup on the text value of a JSON key."""
class KeyTransformNumericLookupMixin:
    """
    Mixin for comparison lookups on JSON key values: on backends without a
    native JSON field, the right-hand side parameters are decoded from JSON.
    """

    def process_rhs(self, compiler, connection):
        """Return the rhs SQL with JSON-decoded parameters where needed."""
        sql, params = super().process_rhs(compiler, connection)
        if connection.features.has_native_json_field:
            return sql, params
        return sql, list(map(json.loads, params))
class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
    """``lt`` lookup on a JSON key value."""
class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
    """``lte`` lookup on a JSON key value."""
class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
    """``gt`` lookup on a JSON key value."""
class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
    """``gte`` lookup on a JSON key value."""
# Register the key-specific lookups on KeyTransform, in declaration order.
for _key_lookup in (
    KeyTransformIn,
    KeyTransformExact,
    KeyTransformIExact,
    KeyTransformIsNull,
    KeyTransformIContains,
    KeyTransformStartsWith,
    KeyTransformIStartsWith,
    KeyTransformEndsWith,
    KeyTransformIEndsWith,
    KeyTransformRegex,
    KeyTransformIRegex,
    KeyTransformLt,
    KeyTransformLte,
    KeyTransformGt,
    KeyTransformGte,
):
    KeyTransform.register_lookup(_key_lookup)
# Don't leave the loop variable behind in the module namespace.
del _key_lookup
class KeyTransformFactory:
    """Callable that builds a KeyTransform for a fixed key name."""

    def __init__(self, key_name):
        """Remember the key name handed to every transform created later."""
        self.key_name = key_name

    def __call__(self, *args, **kwargs):
        """Return ``KeyTransform(key_name, *args, **kwargs)``."""
        transform = KeyTransform(self.key_name, *args, **kwargs)
        return transform
|