# json.py
  1. import json
  2. from django import forms
  3. from django.core import checks, exceptions
  4. from django.db import NotSupportedError, connections, router
  5. from django.db.models import lookups
  6. from django.db.models.constants import LOOKUP_SEP
  7. from django.db.models.fields import TextField
  8. from django.db.models.lookups import PostgresOperatorLookup, Transform
  9. from django.utils.translation import gettext_lazy as _
  10. from . import Field
  11. from .mixins import CheckFieldDefaultMixin
  12. __all__ = ["JSONField"]
  13. class JSONField(CheckFieldDefaultMixin, Field):
  14. empty_strings_allowed = False
  15. description = _("A JSON object")
  16. default_error_messages = {
  17. "invalid": _("Value must be valid JSON."),
  18. }
  19. _default_hint = ("dict", "{}")
  20. def __init__(
  21. self,
  22. verbose_name=None,
  23. name=None,
  24. encoder=None,
  25. decoder=None,
  26. **kwargs,
  27. ):
  28. if encoder and not callable(encoder):
  29. raise ValueError("The encoder parameter must be a callable object.")
  30. if decoder and not callable(decoder):
  31. raise ValueError("The decoder parameter must be a callable object.")
  32. self.encoder = encoder
  33. self.decoder = decoder
  34. super().__init__(verbose_name, name, **kwargs)
  35. def check(self, **kwargs):
  36. errors = super().check(**kwargs)
  37. databases = kwargs.get("databases") or []
  38. errors.extend(self._check_supported(databases))
  39. return errors
  40. def _check_supported(self, databases):
  41. errors = []
  42. for db in databases:
  43. if not router.allow_migrate_model(db, self.model):
  44. continue
  45. connection = connections[db]
  46. if (
  47. self.model._meta.required_db_vendor
  48. and self.model._meta.required_db_vendor != connection.vendor
  49. ):
  50. continue
  51. if not (
  52. "supports_json_field" in self.model._meta.required_db_features
  53. or connection.features.supports_json_field
  54. ):
  55. errors.append(
  56. checks.Error(
  57. "%s does not support JSONFields." % connection.display_name,
  58. obj=self.model,
  59. id="fields.E180",
  60. )
  61. )
  62. return errors
  63. def deconstruct(self):
  64. name, path, args, kwargs = super().deconstruct()
  65. if self.encoder is not None:
  66. kwargs["encoder"] = self.encoder
  67. if self.decoder is not None:
  68. kwargs["decoder"] = self.decoder
  69. return name, path, args, kwargs
  70. def from_db_value(self, value, expression, connection):
  71. if value is None:
  72. return value
  73. # Some backends (SQLite at least) extract non-string values in their
  74. # SQL datatypes.
  75. if isinstance(expression, KeyTransform) and not isinstance(value, str):
  76. return value
  77. try:
  78. return json.loads(value, cls=self.decoder)
  79. except json.JSONDecodeError:
  80. return value
  81. def get_internal_type(self):
  82. return "JSONField"
  83. def get_prep_value(self, value):
  84. if value is None:
  85. return value
  86. return json.dumps(value, cls=self.encoder)
  87. def get_transform(self, name):
  88. transform = super().get_transform(name)
  89. if transform:
  90. return transform
  91. return KeyTransformFactory(name)
  92. def validate(self, value, model_instance):
  93. super().validate(value, model_instance)
  94. try:
  95. json.dumps(value, cls=self.encoder)
  96. except TypeError:
  97. raise exceptions.ValidationError(
  98. self.error_messages["invalid"],
  99. code="invalid",
  100. params={"value": value},
  101. )
  102. def value_to_string(self, obj):
  103. return self.value_from_object(obj)
  104. def formfield(self, **kwargs):
  105. return super().formfield(
  106. **{
  107. "form_class": forms.JSONField,
  108. "encoder": self.encoder,
  109. "decoder": self.decoder,
  110. **kwargs,
  111. }
  112. )
  113. def compile_json_path(key_transforms, include_root=True):
  114. path = ["$"] if include_root else []
  115. for key_transform in key_transforms:
  116. try:
  117. num = int(key_transform)
  118. except ValueError: # non-integer
  119. path.append(".")
  120. path.append(json.dumps(key_transform))
  121. else:
  122. path.append("[%s]" % num)
  123. return "".join(path)
  124. class DataContains(PostgresOperatorLookup):
  125. lookup_name = "contains"
  126. postgres_operator = "@>"
  127. def as_sql(self, compiler, connection):
  128. if not connection.features.supports_json_field_contains:
  129. raise NotSupportedError(
  130. "contains lookup is not supported on this database backend."
  131. )
  132. lhs, lhs_params = self.process_lhs(compiler, connection)
  133. rhs, rhs_params = self.process_rhs(compiler, connection)
  134. params = tuple(lhs_params) + tuple(rhs_params)
  135. return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params
  136. class ContainedBy(PostgresOperatorLookup):
  137. lookup_name = "contained_by"
  138. postgres_operator = "<@"
  139. def as_sql(self, compiler, connection):
  140. if not connection.features.supports_json_field_contains:
  141. raise NotSupportedError(
  142. "contained_by lookup is not supported on this database backend."
  143. )
  144. lhs, lhs_params = self.process_lhs(compiler, connection)
  145. rhs, rhs_params = self.process_rhs(compiler, connection)
  146. params = tuple(rhs_params) + tuple(lhs_params)
  147. return "JSON_CONTAINS(%s, %s)" % (rhs, lhs), params
  148. class HasKeyLookup(PostgresOperatorLookup):
  149. logical_operator = None
  150. def compile_json_path_final_key(self, key_transform):
  151. # Compile the final key without interpreting ints as array elements.
  152. return ".%s" % json.dumps(key_transform)
  153. def as_sql(self, compiler, connection, template=None):
  154. # Process JSON path from the left-hand side.
  155. if isinstance(self.lhs, KeyTransform):
  156. lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
  157. compiler, connection
  158. )
  159. lhs_json_path = compile_json_path(lhs_key_transforms)
  160. else:
  161. lhs, lhs_params = self.process_lhs(compiler, connection)
  162. lhs_json_path = "$"
  163. sql = template % lhs
  164. # Process JSON path from the right-hand side.
  165. rhs = self.rhs
  166. rhs_params = []
  167. if not isinstance(rhs, (list, tuple)):
  168. rhs = [rhs]
  169. for key in rhs:
  170. if isinstance(key, KeyTransform):
  171. *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
  172. else:
  173. rhs_key_transforms = [key]
  174. *rhs_key_transforms, final_key = rhs_key_transforms
  175. rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
  176. rhs_json_path += self.compile_json_path_final_key(final_key)
  177. rhs_params.append(lhs_json_path + rhs_json_path)
  178. # Add condition for each key.
  179. if self.logical_operator:
  180. sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params))
  181. return sql, tuple(lhs_params) + tuple(rhs_params)
  182. def as_mysql(self, compiler, connection):
  183. return self.as_sql(
  184. compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
  185. )
  186. def as_oracle(self, compiler, connection):
  187. sql, params = self.as_sql(
  188. compiler, connection, template="JSON_EXISTS(%s, '%%s')"
  189. )
  190. # Add paths directly into SQL because path expressions cannot be passed
  191. # as bind variables on Oracle.
  192. return sql % tuple(params), []
  193. def as_postgresql(self, compiler, connection):
  194. if isinstance(self.rhs, KeyTransform):
  195. *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
  196. for key in rhs_key_transforms[:-1]:
  197. self.lhs = KeyTransform(key, self.lhs)
  198. self.rhs = rhs_key_transforms[-1]
  199. return super().as_postgresql(compiler, connection)
  200. def as_sqlite(self, compiler, connection):
  201. return self.as_sql(
  202. compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
  203. )
  204. class HasKey(HasKeyLookup):
  205. lookup_name = "has_key"
  206. postgres_operator = "?"
  207. prepare_rhs = False
  208. class HasKeys(HasKeyLookup):
  209. lookup_name = "has_keys"
  210. postgres_operator = "?&"
  211. logical_operator = " AND "
  212. def get_prep_lookup(self):
  213. return [str(item) for item in self.rhs]
  214. class HasAnyKeys(HasKeys):
  215. lookup_name = "has_any_keys"
  216. postgres_operator = "?|"
  217. logical_operator = " OR "
  218. class HasKeyOrArrayIndex(HasKey):
  219. def compile_json_path_final_key(self, key_transform):
  220. return compile_json_path([key_transform], include_root=False)
  221. class CaseInsensitiveMixin:
  222. """
  223. Mixin to allow case-insensitive comparison of JSON values on MySQL.
  224. MySQL handles strings used in JSON context using the utf8mb4_bin collation.
  225. Because utf8mb4_bin is a binary collation, comparison of JSON values is
  226. case-sensitive.
  227. """
  228. def process_lhs(self, compiler, connection):
  229. lhs, lhs_params = super().process_lhs(compiler, connection)
  230. if connection.vendor == "mysql":
  231. return "LOWER(%s)" % lhs, lhs_params
  232. return lhs, lhs_params
  233. def process_rhs(self, compiler, connection):
  234. rhs, rhs_params = super().process_rhs(compiler, connection)
  235. if connection.vendor == "mysql":
  236. return "LOWER(%s)" % rhs, rhs_params
  237. return rhs, rhs_params
  238. class JSONExact(lookups.Exact):
  239. can_use_none_as_rhs = True
  240. def process_rhs(self, compiler, connection):
  241. rhs, rhs_params = super().process_rhs(compiler, connection)
  242. # Treat None lookup values as null.
  243. if rhs == "%s" and rhs_params == [None]:
  244. rhs_params = ["null"]
  245. if connection.vendor == "mysql":
  246. func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
  247. rhs = rhs % tuple(func)
  248. return rhs, rhs_params
  249. class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
  250. pass
  251. JSONField.register_lookup(DataContains)
  252. JSONField.register_lookup(ContainedBy)
  253. JSONField.register_lookup(HasKey)
  254. JSONField.register_lookup(HasKeys)
  255. JSONField.register_lookup(HasAnyKeys)
  256. JSONField.register_lookup(JSONExact)
  257. JSONField.register_lookup(JSONIContains)
  258. class KeyTransform(Transform):
  259. postgres_operator = "->"
  260. postgres_nested_operator = "#>"
  261. def __init__(self, key_name, *args, **kwargs):
  262. super().__init__(*args, **kwargs)
  263. self.key_name = str(key_name)
  264. def preprocess_lhs(self, compiler, connection):
  265. key_transforms = [self.key_name]
  266. previous = self.lhs
  267. while isinstance(previous, KeyTransform):
  268. key_transforms.insert(0, previous.key_name)
  269. previous = previous.lhs
  270. lhs, params = compiler.compile(previous)
  271. if connection.vendor == "oracle":
  272. # Escape string-formatting.
  273. key_transforms = [key.replace("%", "%%") for key in key_transforms]
  274. return lhs, params, key_transforms
  275. def as_mysql(self, compiler, connection):
  276. lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
  277. json_path = compile_json_path(key_transforms)
  278. return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,)
  279. def as_oracle(self, compiler, connection):
  280. lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
  281. json_path = compile_json_path(key_transforms)
  282. return (
  283. "COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))"
  284. % ((lhs, json_path) * 2)
  285. ), tuple(params) * 2
  286. def as_postgresql(self, compiler, connection):
  287. lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
  288. if len(key_transforms) > 1:
  289. sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator)
  290. return sql, tuple(params) + (key_transforms,)
  291. try:
  292. lookup = int(self.key_name)
  293. except ValueError:
  294. lookup = self.key_name
  295. return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,)
  296. def as_sqlite(self, compiler, connection):
  297. lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
  298. json_path = compile_json_path(key_transforms)
  299. datatype_values = ",".join(
  300. [repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
  301. )
  302. return (
  303. "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
  304. "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
  305. ) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3
  306. class KeyTextTransform(KeyTransform):
  307. postgres_operator = "->>"
  308. postgres_nested_operator = "#>>"
  309. output_field = TextField()
  310. def as_mysql(self, compiler, connection):
  311. if connection.mysql_is_mariadb:
  312. # MariaDB doesn't support -> and ->> operators (see MDEV-13594).
  313. sql, params = super().as_mysql(compiler, connection)
  314. return "JSON_UNQUOTE(%s)" % sql, params
  315. else:
  316. lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
  317. json_path = compile_json_path(key_transforms)
  318. return "(%s ->> %%s)" % lhs, tuple(params) + (json_path,)
  319. @classmethod
  320. def from_lookup(cls, lookup):
  321. transform, *keys = lookup.split(LOOKUP_SEP)
  322. if not keys:
  323. raise ValueError("Lookup must contain key or index transforms.")
  324. for key in keys:
  325. transform = cls(key, transform)
  326. return transform
  327. KT = KeyTextTransform.from_lookup
  328. class KeyTransformTextLookupMixin:
  329. """
  330. Mixin for combining with a lookup expecting a text lhs from a JSONField
  331. key lookup. On PostgreSQL, make use of the ->> operator instead of casting
  332. key values to text and performing the lookup on the resulting
  333. representation.
  334. """
  335. def __init__(self, key_transform, *args, **kwargs):
  336. if not isinstance(key_transform, KeyTransform):
  337. raise TypeError(
  338. "Transform should be an instance of KeyTransform in order to "
  339. "use this lookup."
  340. )
  341. key_text_transform = KeyTextTransform(
  342. key_transform.key_name,
  343. *key_transform.source_expressions,
  344. **key_transform.extra,
  345. )
  346. super().__init__(key_text_transform, *args, **kwargs)
  347. class KeyTransformIsNull(lookups.IsNull):
  348. # key__isnull=False is the same as has_key='key'
  349. def as_oracle(self, compiler, connection):
  350. sql, params = HasKeyOrArrayIndex(
  351. self.lhs.lhs,
  352. self.lhs.key_name,
  353. ).as_oracle(compiler, connection)
  354. if not self.rhs:
  355. return sql, params
  356. # Column doesn't have a key or IS NULL.
  357. lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
  358. return "(NOT %s OR %s IS NULL)" % (sql, lhs), tuple(params) + tuple(lhs_params)
  359. def as_sqlite(self, compiler, connection):
  360. template = "JSON_TYPE(%s, %%s) IS NULL"
  361. if not self.rhs:
  362. template = "JSON_TYPE(%s, %%s) IS NOT NULL"
  363. return HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name).as_sql(
  364. compiler,
  365. connection,
  366. template=template,
  367. )
  368. class KeyTransformIn(lookups.In):
  369. def resolve_expression_parameter(self, compiler, connection, sql, param):
  370. sql, params = super().resolve_expression_parameter(
  371. compiler,
  372. connection,
  373. sql,
  374. param,
  375. )
  376. if (
  377. not hasattr(param, "as_sql")
  378. and not connection.features.has_native_json_field
  379. ):
  380. if connection.vendor == "oracle":
  381. value = json.loads(param)
  382. sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
  383. if isinstance(value, (list, dict)):
  384. sql = sql % "JSON_QUERY"
  385. else:
  386. sql = sql % "JSON_VALUE"
  387. elif connection.vendor == "mysql" or (
  388. connection.vendor == "sqlite"
  389. and params[0] not in connection.ops.jsonfield_datatype_values
  390. ):
  391. sql = "JSON_EXTRACT(%s, '$')"
  392. if connection.vendor == "mysql" and connection.mysql_is_mariadb:
  393. sql = "JSON_UNQUOTE(%s)" % sql
  394. return sql, params
  395. class KeyTransformExact(JSONExact):
  396. def process_rhs(self, compiler, connection):
  397. if isinstance(self.rhs, KeyTransform):
  398. return super(lookups.Exact, self).process_rhs(compiler, connection)
  399. rhs, rhs_params = super().process_rhs(compiler, connection)
  400. if connection.vendor == "oracle":
  401. func = []
  402. sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
  403. for value in rhs_params:
  404. value = json.loads(value)
  405. if isinstance(value, (list, dict)):
  406. func.append(sql % "JSON_QUERY")
  407. else:
  408. func.append(sql % "JSON_VALUE")
  409. rhs = rhs % tuple(func)
  410. elif connection.vendor == "sqlite":
  411. func = []
  412. for value in rhs_params:
  413. if value in connection.ops.jsonfield_datatype_values:
  414. func.append("%s")
  415. else:
  416. func.append("JSON_EXTRACT(%s, '$')")
  417. rhs = rhs % tuple(func)
  418. return rhs, rhs_params
  419. def as_oracle(self, compiler, connection):
  420. rhs, rhs_params = super().process_rhs(compiler, connection)
  421. if rhs_params == ["null"]:
  422. # Field has key and it's NULL.
  423. has_key_expr = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
  424. has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
  425. is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True)
  426. is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
  427. return (
  428. "%s AND %s" % (has_key_sql, is_null_sql),
  429. tuple(has_key_params) + tuple(is_null_params),
  430. )
  431. return super().as_sql(compiler, connection)
  432. class KeyTransformIExact(
  433. CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
  434. ):
  435. pass
  436. class KeyTransformIContains(
  437. CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
  438. ):
  439. pass
  440. class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
  441. pass
  442. class KeyTransformIStartsWith(
  443. CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
  444. ):
  445. pass
  446. class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
  447. pass
  448. class KeyTransformIEndsWith(
  449. CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
  450. ):
  451. pass
  452. class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
  453. pass
  454. class KeyTransformIRegex(
  455. CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
  456. ):
  457. pass
  458. class KeyTransformNumericLookupMixin:
  459. def process_rhs(self, compiler, connection):
  460. rhs, rhs_params = super().process_rhs(compiler, connection)
  461. if not connection.features.has_native_json_field:
  462. rhs_params = [json.loads(value) for value in rhs_params]
  463. return rhs, rhs_params
  464. class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
  465. pass
  466. class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
  467. pass
  468. class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
  469. pass
  470. class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
  471. pass
  472. KeyTransform.register_lookup(KeyTransformIn)
  473. KeyTransform.register_lookup(KeyTransformExact)
  474. KeyTransform.register_lookup(KeyTransformIExact)
  475. KeyTransform.register_lookup(KeyTransformIsNull)
  476. KeyTransform.register_lookup(KeyTransformIContains)
  477. KeyTransform.register_lookup(KeyTransformStartsWith)
  478. KeyTransform.register_lookup(KeyTransformIStartsWith)
  479. KeyTransform.register_lookup(KeyTransformEndsWith)
  480. KeyTransform.register_lookup(KeyTransformIEndsWith)
  481. KeyTransform.register_lookup(KeyTransformRegex)
  482. KeyTransform.register_lookup(KeyTransformIRegex)
  483. KeyTransform.register_lookup(KeyTransformLt)
  484. KeyTransform.register_lookup(KeyTransformLte)
  485. KeyTransform.register_lookup(KeyTransformGt)
  486. KeyTransform.register_lookup(KeyTransformGte)
  487. class KeyTransformFactory:
  488. def __init__(self, key_name):
  489. self.key_name = key_name
  490. def __call__(self, *args, **kwargs):
  491. return KeyTransform(self.key_name, *args, **kwargs)