Browse Source

Refs #33374 -- Adjusted full match condition handling.

Adjusting WhereNode.as_sql() to raise an exception when encoutering a
full match just like with empty matches ensures that all case are
explicitly handled.
Simon Charette 2 years ago
parent
commit
76e37513e2

+ 6 - 0
django/core/exceptions.py

@@ -233,6 +233,12 @@ class EmptyResultSet(Exception):
     pass
 
 
+class FullResultSet(Exception):
+    """A database query predicate is matches everything."""
+
+    pass
+
+
 class SynchronousOnlyOperation(Exception):
     """The user tried to call a sync-only function from an async context."""
 

+ 9 - 5
django/db/backends/mysql/compiler.py

@@ -1,4 +1,4 @@
-from django.core.exceptions import FieldError
+from django.core.exceptions import FieldError, FullResultSet
 from django.db.models.expressions import Col
 from django.db.models.sql import compiler
 
@@ -40,12 +40,16 @@ class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
             "DELETE %s FROM"
             % self.quote_name_unless_alias(self.query.get_initial_alias())
         ]
-        from_sql, from_params = self.get_from_clause()
+        from_sql, params = self.get_from_clause()
         result.extend(from_sql)
-        where_sql, where_params = self.compile(where)
-        if where_sql:
+        try:
+            where_sql, where_params = self.compile(where)
+        except FullResultSet:
+            pass
+        else:
             result.append("WHERE %s" % where_sql)
-        return " ".join(result), tuple(from_params) + tuple(where_params)
+            params.extend(where_params)
+        return " ".join(result), tuple(params)
 
 
 class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):

+ 6 - 3
django/db/models/aggregates.py

@@ -1,7 +1,7 @@
 """
 Classes to represent the definitions of aggregate functions.
 """
-from django.core.exceptions import FieldError
+from django.core.exceptions import FieldError, FullResultSet
 from django.db.models.expressions import Case, Func, Star, When
 from django.db.models.fields import IntegerField
 from django.db.models.functions.comparison import Coalesce
@@ -104,8 +104,11 @@ class Aggregate(Func):
         extra_context["distinct"] = "DISTINCT " if self.distinct else ""
         if self.filter:
             if connection.features.supports_aggregate_filter_clause:
-                filter_sql, filter_params = self.filter.as_sql(compiler, connection)
-                if filter_sql:
+                try:
+                    filter_sql, filter_params = self.filter.as_sql(compiler, connection)
+                except FullResultSet:
+                    pass
+                else:
                     template = self.filter_template % extra_context.get(
                         "template", self.template
                     )

+ 7 - 10
django/db/models/expressions.py

@@ -7,7 +7,7 @@ from collections import defaultdict
 from decimal import Decimal
 from uuid import UUID
 
-from django.core.exceptions import EmptyResultSet, FieldError
+from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
 from django.db import DatabaseError, NotSupportedError, connection
 from django.db.models import fields
 from django.db.models.constants import LOOKUP_SEP
@@ -955,6 +955,8 @@ class Func(SQLiteNumericMixin, Expression):
                 if empty_result_set_value is NotImplemented:
                     raise
                 arg_sql, arg_params = compiler.compile(Value(empty_result_set_value))
+            except FullResultSet:
+                arg_sql, arg_params = compiler.compile(Value(True))
             sql_parts.append(arg_sql)
             params.extend(arg_params)
         data = {**self.extra, **extra_context}
@@ -1367,14 +1369,6 @@ class When(Expression):
         template_params = extra_context
         sql_params = []
         condition_sql, condition_params = compiler.compile(self.condition)
-        # Filters that match everything are handled as empty strings in the
-        # WHERE clause, but in a CASE WHEN expression they must use a predicate
-        # that's always True.
-        if condition_sql == "":
-            if connection.features.supports_boolean_expr_in_select_clause:
-                condition_sql, condition_params = compiler.compile(Value(True))
-            else:
-                condition_sql, condition_params = "1=1", ()
         template_params["condition"] = condition_sql
         result_sql, result_params = compiler.compile(self.result)
         template_params["result"] = result_sql
@@ -1461,14 +1455,17 @@ class Case(SQLiteNumericMixin, Expression):
         template_params = {**self.extra, **extra_context}
         case_parts = []
         sql_params = []
+        default_sql, default_params = compiler.compile(self.default)
         for case in self.cases:
             try:
                 case_sql, case_params = compiler.compile(case)
             except EmptyResultSet:
                 continue
+            except FullResultSet:
+                default_sql, default_params = compiler.compile(case.result)
+                break
             case_parts.append(case_sql)
             sql_params.extend(case_params)
-        default_sql, default_params = compiler.compile(self.default)
         if not case_parts:
             return default_sql, default_params
         case_joiner = case_joiner or self.case_joiner

+ 0 - 9
django/db/models/fields/__init__.py

@@ -1103,15 +1103,6 @@ class BooleanField(Field):
             defaults = {"form_class": form_class, "required": False}
         return super().formfield(**{**defaults, **kwargs})
 
-    def select_format(self, compiler, sql, params):
-        sql, params = super().select_format(compiler, sql, params)
-        # Filters that match everything are handled as empty strings in the
-        # WHERE clause, but in SELECT or GROUP BY list they must use a
-        # predicate that's always True.
-        if sql == "":
-            sql = "1"
-        return sql, params
-
 
 class CharField(Field):
     description = _("String (up to %(max_length)s)")

+ 25 - 12
django/db/models/sql/compiler.py

@@ -4,7 +4,7 @@ import re
 from functools import partial
 from itertools import chain
 
-from django.core.exceptions import EmptyResultSet, FieldError
+from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
 from django.db import DatabaseError, NotSupportedError
 from django.db.models.constants import LOOKUP_SEP
 from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value
@@ -169,7 +169,7 @@ class SQLCompiler:
                 expr = Ref(alias, expr)
             try:
                 sql, params = self.compile(expr)
-            except EmptyResultSet:
+            except (EmptyResultSet, FullResultSet):
                 continue
             sql, params = expr.select_format(self, sql, params)
             params_hash = make_hashable(params)
@@ -287,6 +287,8 @@ class SQLCompiler:
                     sql, params = "0", ()
                 else:
                     sql, params = self.compile(Value(empty_result_set_value))
+            except FullResultSet:
+                sql, params = self.compile(Value(True))
             else:
                 sql, params = col.select_format(self, sql, params)
             if alias is None and with_col_aliases:
@@ -721,9 +723,16 @@ class SQLCompiler:
                         raise
                     # Use a predicate that's always False.
                     where, w_params = "0 = 1", []
-                having, h_params = (
-                    self.compile(self.having) if self.having is not None else ("", [])
-                )
+                except FullResultSet:
+                    where, w_params = "", []
+                try:
+                    having, h_params = (
+                        self.compile(self.having)
+                        if self.having is not None
+                        else ("", [])
+                    )
+                except FullResultSet:
+                    having, h_params = "", []
                 result = ["SELECT"]
                 params = []
 
@@ -1817,11 +1826,12 @@ class SQLDeleteCompiler(SQLCompiler):
         )
 
     def _as_sql(self, query):
-        result = ["DELETE FROM %s" % self.quote_name_unless_alias(query.base_table)]
-        where, params = self.compile(query.where)
-        if where:
-            result.append("WHERE %s" % where)
-        return " ".join(result), tuple(params)
+        delete = "DELETE FROM %s" % self.quote_name_unless_alias(query.base_table)
+        try:
+            where, params = self.compile(query.where)
+        except FullResultSet:
+            return delete, ()
+        return f"{delete} WHERE {where}", tuple(params)
 
     def as_sql(self):
         """
@@ -1906,8 +1916,11 @@ class SQLUpdateCompiler(SQLCompiler):
             "UPDATE %s SET" % qn(table),
             ", ".join(values),
         ]
-        where, params = self.compile(self.query.where)
-        if where:
+        try:
+            where, params = self.compile(self.query.where)
+        except FullResultSet:
+            params = []
+        else:
             result.append("WHERE %s" % where)
         return " ".join(result), tuple(update_params + params)
 

+ 6 - 2
django/db/models/sql/datastructures.py

@@ -2,6 +2,7 @@
 Useful auxiliary data structures for query construction. Not useful outside
 the SQL domain.
 """
+from django.core.exceptions import FullResultSet
 from django.db.models.sql.constants import INNER, LOUTER
 
 
@@ -100,8 +101,11 @@ class Join:
             join_conditions.append("(%s)" % extra_sql)
             params.extend(extra_params)
         if self.filtered_relation:
-            extra_sql, extra_params = compiler.compile(self.filtered_relation)
-            if extra_sql:
+            try:
+                extra_sql, extra_params = compiler.compile(self.filtered_relation)
+            except FullResultSet:
+                pass
+            else:
                 join_conditions.append("(%s)" % extra_sql)
                 params.extend(extra_params)
         if not join_conditions:

+ 14 - 11
django/db/models/sql/where.py

@@ -4,7 +4,7 @@ Code to manage the creation and SQL rendering of 'where' constraints.
 import operator
 from functools import reduce
 
-from django.core.exceptions import EmptyResultSet
+from django.core.exceptions import EmptyResultSet, FullResultSet
 from django.db.models.expressions import Case, When
 from django.db.models.lookups import Exact
 from django.utils import tree
@@ -145,6 +145,8 @@ class WhereNode(tree.Node):
                 sql, params = compiler.compile(child)
             except EmptyResultSet:
                 empty_needed -= 1
+            except FullResultSet:
+                full_needed -= 1
             else:
                 if sql:
                     result.append(sql)
@@ -158,24 +160,25 @@ class WhereNode(tree.Node):
             # counts.
             if empty_needed == 0:
                 if self.negated:
-                    return "", []
+                    raise FullResultSet
                 else:
                     raise EmptyResultSet
             if full_needed == 0:
                 if self.negated:
                     raise EmptyResultSet
                 else:
-                    return "", []
+                    raise FullResultSet
         conn = " %s " % self.connector
         sql_string = conn.join(result)
-        if sql_string:
-            if self.negated:
-                # Some backends (Oracle at least) need parentheses
-                # around the inner SQL in the negated case, even if the
-                # inner SQL contains just a single expression.
-                sql_string = "NOT (%s)" % sql_string
-            elif len(result) > 1 or self.resolved:
-                sql_string = "(%s)" % sql_string
+        if not sql_string:
+            raise FullResultSet
+        if self.negated:
+            # Some backends (Oracle at least) need parentheses around the inner
+            # SQL in the negated case, even if the inner SQL contains just a
+            # single expression.
+            sql_string = "NOT (%s)" % sql_string
+        elif len(result) > 1 or self.resolved:
+            sql_string = "(%s)" % sql_string
         return sql_string, result_params
 
     def get_group_by_cols(self):

+ 11 - 0
docs/ref/exceptions.txt

@@ -42,6 +42,17 @@ Django core exception classes are defined in ``django.core.exceptions``.
     return any results. Most Django projects won't encounter this exception,
     but it might be useful for implementing custom lookups and expressions.
 
+``FullResultSet``
+-----------------
+
+.. exception:: FullResultSet
+
+.. versionadded:: 4.2
+
+    ``FullResultSet`` may be raised during query generation if a query will
+    match everything. Most Django projects won't encounter this exception, but
+    it might be useful for implementing custom lookups and expressions.
+
 ``FieldDoesNotExist``
 ---------------------
 

+ 17 - 2
tests/annotations/tests.py

@@ -24,7 +24,15 @@ from django.db.models import (
     When,
 )
 from django.db.models.expressions import RawSQL
-from django.db.models.functions import Coalesce, ExtractYear, Floor, Length, Lower, Trim
+from django.db.models.functions import (
+    Cast,
+    Coalesce,
+    ExtractYear,
+    Floor,
+    Length,
+    Lower,
+    Trim,
+)
 from django.test import TestCase, skipUnlessDBFeature
 from django.test.utils import register_lookup
 
@@ -282,6 +290,13 @@ class NonAggregateAnnotationTestCase(TestCase):
         self.assertEqual(len(books), Book.objects.count())
         self.assertTrue(all(book.selected for book in books))
 
+    def test_full_expression_wrapped_annotation(self):
+        books = Book.objects.annotate(
+            selected=Coalesce(~Q(pk__in=[]), True),
+        )
+        self.assertEqual(len(books), Book.objects.count())
+        self.assertTrue(all(book.selected for book in books))
+
     def test_full_expression_annotation_with_aggregation(self):
         qs = Book.objects.filter(isbn="159059725").annotate(
             selected=ExpressionWrapper(~Q(pk__in=[]), output_field=BooleanField()),
@@ -292,7 +307,7 @@ class NonAggregateAnnotationTestCase(TestCase):
     def test_aggregate_over_full_expression_annotation(self):
         qs = Book.objects.annotate(
             selected=ExpressionWrapper(~Q(pk__in=[]), output_field=BooleanField()),
-        ).aggregate(Sum("selected"))
+        ).aggregate(selected__sum=Sum(Cast("selected", IntegerField())))
         self.assertEqual(qs["selected__sum"], Book.objects.count())
 
     def test_empty_queryset_annotation(self):

+ 13 - 7
tests/queries/tests.py

@@ -5,7 +5,7 @@ import unittest
 from operator import attrgetter
 from threading import Lock
 
-from django.core.exceptions import EmptyResultSet, FieldError
+from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
 from django.db import DEFAULT_DB_ALIAS, connection
 from django.db.models import CharField, Count, Exists, F, Max, OuterRef, Q
 from django.db.models.expressions import RawSQL
@@ -3588,7 +3588,8 @@ class WhereNodeTest(SimpleTestCase):
         with self.assertRaises(EmptyResultSet):
             w.as_sql(compiler, connection)
         w.negate()
-        self.assertEqual(w.as_sql(compiler, connection), ("", []))
+        with self.assertRaises(FullResultSet):
+            w.as_sql(compiler, connection)
         w = WhereNode(children=[self.DummyNode(), self.DummyNode()])
         self.assertEqual(w.as_sql(compiler, connection), ("(dummy AND dummy)", []))
         w.negate()
@@ -3597,7 +3598,8 @@ class WhereNodeTest(SimpleTestCase):
         with self.assertRaises(EmptyResultSet):
             w.as_sql(compiler, connection)
         w.negate()
-        self.assertEqual(w.as_sql(compiler, connection), ("", []))
+        with self.assertRaises(FullResultSet):
+            w.as_sql(compiler, connection)
 
     def test_empty_full_handling_disjunction(self):
         compiler = WhereNodeTest.MockCompiler()
@@ -3605,7 +3607,8 @@ class WhereNodeTest(SimpleTestCase):
         with self.assertRaises(EmptyResultSet):
             w.as_sql(compiler, connection)
         w.negate()
-        self.assertEqual(w.as_sql(compiler, connection), ("", []))
+        with self.assertRaises(FullResultSet):
+            w.as_sql(compiler, connection)
         w = WhereNode(children=[self.DummyNode(), self.DummyNode()], connector=OR)
         self.assertEqual(w.as_sql(compiler, connection), ("(dummy OR dummy)", []))
         w.negate()
@@ -3619,7 +3622,8 @@ class WhereNodeTest(SimpleTestCase):
         compiler = WhereNodeTest.MockCompiler()
         empty_w = WhereNode()
         w = WhereNode(children=[empty_w, empty_w])
-        self.assertEqual(w.as_sql(compiler, connection), ("", []))
+        with self.assertRaises(FullResultSet):
+            w.as_sql(compiler, connection)
         w.negate()
         with self.assertRaises(EmptyResultSet):
             w.as_sql(compiler, connection)
@@ -3627,9 +3631,11 @@ class WhereNodeTest(SimpleTestCase):
         with self.assertRaises(EmptyResultSet):
             w.as_sql(compiler, connection)
         w.negate()
-        self.assertEqual(w.as_sql(compiler, connection), ("", []))
+        with self.assertRaises(FullResultSet):
+            w.as_sql(compiler, connection)
         w = WhereNode(children=[empty_w, NothingNode()], connector=OR)
-        self.assertEqual(w.as_sql(compiler, connection), ("", []))
+        with self.assertRaises(FullResultSet):
+            w.as_sql(compiler, connection)
         w = WhereNode(children=[empty_w, NothingNode()], connector=AND)
         with self.assertRaises(EmptyResultSet):
             w.as_sql(compiler, connection)