2
0
Эх сурвалжийг харах

Refs #30581 -- Added Q.check() hook.

Gagaro 3 жил өмнө
parent
commit
5d91dc8ee3

+ 29 - 0
django/db/models/query_utils.py

@@ -8,12 +8,16 @@ circular import difficulties.
 import copy
 import functools
 import inspect
+import logging
 from collections import namedtuple
 
 from django.core.exceptions import FieldError
+from django.db import DEFAULT_DB_ALIAS, DatabaseError
 from django.db.models.constants import LOOKUP_SEP
 from django.utils import tree
 
+logger = logging.getLogger("django.db.models")
+
 # PathInfo is used when converting lookups (fk__somecol). The contents
 # describe the relation in Model terms (model Options and Fields for both
 # sides of the relation. The join_field is the field backing the relation.
@@ -110,6 +114,31 @@ class Q(tree.Node):
             else:
                 yield child
 
+    def check(self, against, using=DEFAULT_DB_ALIAS):
+        """
+        Do a database query to check if the expressions of the Q instance
+        matches against the expressions.
+        """
+        # Avoid circular imports.
+        from django.db.models import Value
+        from django.db.models.sql import Query
+        from django.db.models.sql.constants import SINGLE
+
+        query = Query(None)
+        for name, value in against.items():
+            if not hasattr(value, "resolve_expression"):
+                value = Value(value)
+            query.add_annotation(value, name, select=False)
+        query.add_annotation(Value(1), "_check")
+        # This will raise a FieldError if a field is missing in "against".
+        query.add_q(self)
+        compiler = query.get_compiler(using=using)
+        try:
+            return compiler.execute_sql(SINGLE) is not None
+        except DatabaseError as e:
+            logger.warning("Got a database error calling check() on %r: %s", self, e)
+            return True
+
     def deconstruct(self):
         path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
         if path.startswith("django.db.models.query_utils"):

+ 39 - 1
tests/queries/test_q.py

@@ -1,3 +1,4 @@
+from django.core.exceptions import FieldError
 from django.db.models import (
     BooleanField,
     Exists,
@@ -10,7 +11,7 @@ from django.db.models import (
 from django.db.models.expressions import RawSQL
 from django.db.models.functions import Lower
 from django.db.models.sql.where import NothingNode
-from django.test import SimpleTestCase
+from django.test import SimpleTestCase, TestCase
 
 from .models import Tag
 
@@ -214,3 +215,40 @@ class QTests(SimpleTestCase):
         )
         flatten = list(q.flatten())
         self.assertEqual(len(flatten), 7)
+
+
+class QCheckTests(TestCase):
+    def test_basic(self):
+        q = Q(price__gt=20)
+        self.assertIs(q.check({"price": 30}), True)
+        self.assertIs(q.check({"price": 10}), False)
+
+    def test_expression(self):
+        q = Q(name="test")
+        self.assertIs(q.check({"name": Lower(Value("TeSt"))}), True)
+        self.assertIs(q.check({"name": Value("other")}), False)
+
+    def test_missing_field(self):
+        q = Q(description__startswith="prefix")
+        msg = "Cannot resolve keyword 'description' into field."
+        with self.assertRaisesMessage(FieldError, msg):
+            q.check({"name": "test"})
+
+    def test_boolean_expression(self):
+        q = Q(ExpressionWrapper(Q(price__gt=20), output_field=BooleanField()))
+        self.assertIs(q.check({"price": 25}), True)
+        self.assertIs(q.check({"price": Value(10)}), False)
+
+    def test_rawsql(self):
+        """
+        RawSQL expressions cause a database error because "price" cannot be
+        replaced by its value. In this case, Q.check() logs a warning and
+        return True.
+        """
+        q = Q(RawSQL("price > %s", params=(20,), output_field=BooleanField()))
+        with self.assertLogs("django.db.models", "WARNING") as cm:
+            self.assertIs(q.check({"price": 10}), True)
+        self.assertIn(
+            f"Got a database error calling check() on {q!r}: ",
+            cm.records[0].getMessage(),
+        )