2
0
Эх сурвалжийг харах

Refs #33308 -- Used get_db_prep_value() to adapt JSONFields.

Simon Charette 2 жил өмнө
parent
commit
5c23d9f0c3

+ 4 - 0
django/db/backends/base/operations.py

@@ -1,5 +1,6 @@
 import datetime
 import decimal
+import json
 from importlib import import_module
 
 import sqlparse
@@ -575,6 +576,9 @@ class BaseDatabaseOperations:
         """
         return value or None
 
+    def adapt_json_value(self, value, encoder):
+        return json.dumps(value, cls=encoder)
+
     def year_lookup_bounds_for_date_field(self, value, iso_year=False):
         """
         Return a two-elements list with the lower and upper bound to be used

+ 14 - 0
django/db/backends/postgresql/operations.py

@@ -1,4 +1,8 @@
+import json
+from functools import lru_cache, partial
+
 from psycopg2.extras import Inet
+from psycopg2.extras import Json as Jsonb
 
 from django.conf import settings
 from django.db.backends.base.operations import BaseDatabaseOperations
@@ -6,6 +10,13 @@ from django.db.backends.utils import split_tzname_delta
 from django.db.models.constants import OnConflict
 
 
+@lru_cache
+def get_json_dumps(encoder):
+    if encoder is None:
+        return json.dumps
+    return partial(json.dumps, cls=encoder)
+
+
 class DatabaseOperations(BaseDatabaseOperations):
     cast_char_field_without_max_length = "varchar"
     explain_prefix = "EXPLAIN"
@@ -308,6 +319,9 @@ class DatabaseOperations(BaseDatabaseOperations):
             return Inet(value)
         return None
 
+    def adapt_json_value(self, value, encoder):
+        return Jsonb(value, dumps=get_json_dumps(encoder))
+
     def subtract_temporals(self, internal_type, lhs, rhs):
         if internal_type == "DateField":
             lhs_sql, lhs_params = lhs

+ 14 - 5
django/db/models/fields/json.py

@@ -6,7 +6,11 @@ from django.db import NotSupportedError, connections, router
 from django.db.models import lookups
 from django.db.models.constants import LOOKUP_SEP
 from django.db.models.fields import TextField
-from django.db.models.lookups import PostgresOperatorLookup, Transform
+from django.db.models.lookups import (
+    FieldGetDbPrepValueMixin,
+    PostgresOperatorLookup,
+    Transform,
+)
 from django.utils.translation import gettext_lazy as _
 
 from . import Field
@@ -92,10 +96,15 @@ class JSONField(CheckFieldDefaultMixin, Field):
     def get_internal_type(self):
         return "JSONField"
 
-    def get_prep_value(self, value):
+    def get_db_prep_value(self, value, connection, prepared=False):
+        if hasattr(value, "as_sql"):
+            return value
+        return connection.ops.adapt_json_value(value, self.encoder)
+
+    def get_db_prep_save(self, value, connection):
         if value is None:
             return value
-        return json.dumps(value, cls=self.encoder)
+        return self.get_db_prep_value(value, connection)
 
     def get_transform(self, name):
         transform = super().get_transform(name)
@@ -141,7 +150,7 @@ def compile_json_path(key_transforms, include_root=True):
     return "".join(path)
 
 
-class DataContains(PostgresOperatorLookup):
+class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
     lookup_name = "contains"
     postgres_operator = "@>"
 
@@ -156,7 +165,7 @@ class DataContains(PostgresOperatorLookup):
         return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params
 
 
-class ContainedBy(PostgresOperatorLookup):
+class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
     lookup_name = "contained_by"
     postgres_operator = "<@"