Pārlūkot izejas kodu

Refs #29850 -- Added exclusion support to window frames.

Sarah Boyce 1 gadu atpakaļ
vecāks
revīzija
e4d012ca05

+ 1 - 0
django/db/backends/base/features.py

@@ -263,6 +263,7 @@ class BaseDatabaseFeatures:
     # Does the backend support window expressions (expression OVER (...))?
     supports_over_clause = False
     supports_frame_range_fixed_distance = False
+    supports_frame_exclusion = False
     only_supports_unbounded_with_preceding_and_following = False
 
     # Does the backend support CAST with precision?

+ 4 - 0
django/db/backends/oracle/features.py

@@ -162,3 +162,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     @cached_property
     def supports_primitives_in_json_field(self):
         return self.connection.oracle_version >= (21,)
+
+    @cached_property
+    def supports_frame_exclusion(self):
+        return self.connection.oracle_version >= (21,)

+ 1 - 0
django/db/backends/postgresql/features.py

@@ -61,6 +61,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     """
     requires_casted_case_in_updates = True
     supports_over_clause = True
+    supports_frame_exclusion = True
     only_supports_unbounded_with_preceding_and_following = True
     supports_aggregate_filter_clause = True
     supported_explain_formats = {"JSON", "TEXT", "XML", "YAML"}

+ 1 - 0
django/db/backends/sqlite3/features.py

@@ -32,6 +32,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     can_defer_constraint_checks = True
     supports_over_clause = True
     supports_frame_range_fixed_distance = Database.sqlite_version_info >= (3, 28, 0)
+    supports_frame_exclusion = Database.sqlite_version_info >= (3, 28, 0)
     supports_aggregate_filter_clause = Database.sqlite_version_info >= (3, 30, 1)
     supports_order_by_nulls_modifier = Database.sqlite_version_info >= (3, 30, 0)
     # NULLS LAST/FIRST emulation on < 3.30 requires subquery wrapping.

+ 2 - 0
django/db/models/__init__.py

@@ -34,6 +34,7 @@ from django.db.models.expressions import (
     When,
     Window,
     WindowFrame,
+    WindowFrameExclusion,
 )
 from django.db.models.fields import *  # NOQA
 from django.db.models.fields import __all__ as fields_all
@@ -91,6 +92,7 @@ __all__ += [
     "When",
     "Window",
     "WindowFrame",
+    "WindowFrameExclusion",
     "FileField",
     "ImageField",
     "GeneratedField",

+ 30 - 2
django/db/models/expressions.py

@@ -4,6 +4,7 @@ import functools
 import inspect
 from collections import defaultdict
 from decimal import Decimal
+from enum import Enum
 from types import NoneType
 from uuid import UUID
 
@@ -1848,6 +1849,16 @@ class Window(SQLiteNumericMixin, Expression):
         return group_by_cols
 
 
+class WindowFrameExclusion(Enum):
+    CURRENT_ROW = "CURRENT ROW"
+    GROUP = "GROUP"
+    TIES = "TIES"
+    NO_OTHERS = "NO OTHERS"
+
+    def __repr__(self):
+        return f"{self.__class__.__qualname__}.{self._name_}"
+
+
 class WindowFrame(Expression):
     """
     Model the frame clause in window expressions. There are two types of frame
@@ -1857,11 +1868,17 @@ class WindowFrame(Expression):
     row in the frame).
     """
 
-    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"
+    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s%(exclude)s"
 
-    def __init__(self, start=None, end=None):
+    def __init__(self, start=None, end=None, exclusion=None):
         self.start = Value(start)
         self.end = Value(end)
+        if not isinstance(exclusion, (NoneType, WindowFrameExclusion)):
+            raise TypeError(
+                f"{self.__class__.__qualname__}.exclusion must be a "
+                "WindowFrameExclusion instance."
+            )
+        self.exclusion = exclusion
 
     def set_source_expressions(self, exprs):
         self.start, self.end = exprs
@@ -1869,17 +1886,27 @@ class WindowFrame(Expression):
     def get_source_expressions(self):
         return [self.start, self.end]
 
+    def get_exclusion(self):
+        if self.exclusion is None:
+            return ""
+        return f" EXCLUDE {self.exclusion.value}"
+
     def as_sql(self, compiler, connection):
         connection.ops.check_expression_support(self)
         start, end = self.window_frame_start_end(
             connection, self.start.value, self.end.value
         )
+        if self.exclusion and not connection.features.supports_frame_exclusion:
+            raise NotSupportedError(
+                "This backend does not support window frame exclusions."
+            )
         return (
             self.template
             % {
                 "frame_type": self.frame_type,
                 "start": start,
                 "end": end,
+                "exclude": self.get_exclusion(),
             },
             [],
         )
@@ -1912,6 +1939,7 @@ class WindowFrame(Expression):
             "frame_type": self.frame_type,
             "start": start,
             "end": end,
+            "exclude": self.get_exclusion(),
         }
 
     def window_frame_start_end(self, connection, start, end):

+ 32 - 2
docs/ref/models/expressions.txt

@@ -889,7 +889,7 @@ Frames
 For a window frame, you can choose either a range-based sequence of rows or an
 ordinary sequence of rows.
 
-.. class:: ValueRange(start=None, end=None)
+.. class:: ValueRange(start=None, end=None, exclusion=None)
 
     .. attribute:: frame_type
 
@@ -899,18 +899,48 @@ ordinary sequence of rows.
     the standard start and end points, such as ``CURRENT ROW`` and ``UNBOUNDED
     FOLLOWING``.
 
-.. class:: RowRange(start=None, end=None)
+    .. versionchanged:: 5.1
+
+        The ``exclusion`` argument was added.
+
+.. class:: RowRange(start=None, end=None, exclusion=None)
 
     .. attribute:: frame_type
 
         This attribute is set to ``'ROWS'``.
 
+    .. versionchanged:: 5.1
+
+        The ``exclusion`` argument was added.
+
 Both classes return SQL with the template:
 
 .. code-block:: sql
 
     %(frame_type)s BETWEEN %(start)s AND %(end)s
 
+.. class:: WindowFrameExclusion
+
+    .. versionadded:: 5.1
+
+    .. attribute:: CURRENT_ROW
+
+    .. attribute:: GROUP
+
+    .. attribute:: TIES
+
+    .. attribute:: NO_OTHERS
+
+The ``exclusion`` argument allows excluding rows
+(:attr:`~WindowFrameExclusion.CURRENT_ROW`), groups
+(:attr:`~WindowFrameExclusion.GROUP`), and ties
+(:attr:`~WindowFrameExclusion.TIES`) from the window frames on supported
+databases:
+
+.. code-block:: sql
+
+    %(frame_type)s BETWEEN %(start)s AND %(end)s EXCLUDE %(exclusion)s
+
 Frames narrow the rows that are used for computing the result. They shift from
 some start point to some specified end point. Frames can be used with and
 without partitions, but it's often a good idea to specify an ordering of the

+ 5 - 0
docs/releases/5.1.txt

@@ -174,6 +174,11 @@ Models
 * :class:`~django.db.models.expressions.RowRange` now accepts positive integers
   for the ``start`` argument and negative integers for the ``end`` argument.
 
+* The new ``exclusion`` argument of
+  :class:`~django.db.models.expressions.RowRange` and
+  :class:`~django.db.models.expressions.ValueRange` allows excluding rows,
+  groups, and ties from the window frames.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 

+ 234 - 1
tests/expressions_window/tests.py

@@ -22,6 +22,7 @@ from django.db.models import (
     When,
     Window,
     WindowFrame,
+    WindowFrameExclusion,
 )
 from django.db.models.fields.json import KeyTextTransform, KeyTransform
 from django.db.models.functions import (
@@ -41,7 +42,7 @@ from django.db.models.functions import (
     Upper,
 )
 from django.db.models.lookups import Exact
-from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
+from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature
 from django.test.utils import CaptureQueriesContext
 
 from .models import Classification, Detail, Employee, PastEmployeeDepartment
@@ -1211,6 +1212,47 @@ class WindowFunctionTests(TestCase):
             ordered=False,
         )
 
+    @skipUnlessDBFeature(
+        "supports_frame_exclusion", "supports_frame_range_fixed_distance"
+    )
+    def test_range_exclude_current(self):
+        qs = Employee.objects.annotate(
+            sum=Window(
+                expression=Sum("salary"),
+                order_by=F("salary").asc(),
+                partition_by="department",
+                frame=ValueRange(end=2, exclusion=WindowFrameExclusion.CURRENT_ROW),
+            )
+        ).order_by("department", "salary")
+        self.assertIn(
+            "RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING EXCLUDE CURRENT ROW",
+            str(qs.query),
+        )
+        self.assertQuerySetEqual(
+            qs,
+            [
+                ("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), None),
+                ("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 82000),
+                ("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 82000),
+                ("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 127000),
+                ("Moore", 34000, "IT", datetime.date(2013, 8, 1), None),
+                ("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 34000),
+                ("Johnson", 80000, "Management", datetime.date(2005, 7, 1), None),
+                ("Miller", 100000, "Management", datetime.date(2005, 6, 1), 80000),
+                ("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), None),
+                ("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 38000),
+                ("Brown", 53000, "Sales", datetime.date(2009, 9, 1), None),
+                ("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 53000),
+            ],
+            transform=lambda row: (
+                row.name,
+                row.salary,
+                row.department,
+                row.hire_date,
+                row.sum,
+            ),
+        )
+
     def test_range_unbound(self):
         """A query with RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING."""
         qs = Employee.objects.annotate(
@@ -1289,6 +1331,190 @@ class WindowFunctionTests(TestCase):
             ),
         )
 
+    @skipUnlessDBFeature("supports_frame_exclusion")
+    def test_row_range_rank_exclude_current_row(self):
+        qs = Employee.objects.annotate(
+            avg_salary_cohort=Window(
+                expression=Avg("salary"),
+                order_by=[F("hire_date").asc(), F("name").desc()],
+                frame=RowRange(
+                    start=-1, end=1, exclusion=WindowFrameExclusion.CURRENT_ROW
+                ),
+            )
+        ).order_by("hire_date")
+        self.assertIn(
+            "ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING EXCLUDE CURRENT ROW",
+            str(qs.query),
+        )
+        self.assertQuerySetEqual(
+            qs,
+            [
+                ("Miller", 100000, "Management", datetime.date(2005, 6, 1), 80000),
+                ("Johnson", 80000, "Management", datetime.date(2005, 7, 1), 72500),
+                ("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 67500),
+                ("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 45000),
+                ("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 46000),
+                ("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), 49000),
+                ("Brown", 53000, "Sales", datetime.date(2009, 9, 1), 37500),
+                ("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), 56500),
+                ("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 39000),
+                ("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 55000),
+                ("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 37000),
+                ("Moore", 34000, "IT", datetime.date(2013, 8, 1), 50000),
+            ],
+            transform=lambda row: (
+                row.name,
+                row.salary,
+                row.department,
+                row.hire_date,
+                row.avg_salary_cohort,
+            ),
+        )
+
+    @skipUnlessDBFeature("supports_frame_exclusion")
+    def test_row_range_rank_exclude_group(self):
+        qs = Employee.objects.annotate(
+            avg_salary_cohort=Window(
+                expression=Avg("salary"),
+                order_by=[F("hire_date").asc(), F("name").desc()],
+                frame=RowRange(start=-1, end=1, exclusion=WindowFrameExclusion.GROUP),
+            )
+        ).order_by("hire_date")
+        self.assertIn(
+            "ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING EXCLUDE GROUP",
+            str(qs.query),
+        )
+        self.assertQuerySetEqual(
+            qs,
+            [
+                ("Miller", 100000, "Management", datetime.date(2005, 6, 1), 80000),
+                ("Johnson", 80000, "Management", datetime.date(2005, 7, 1), 72500),
+                ("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 67500),
+                ("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 45000),
+                ("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 46000),
+                ("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), 49000),
+                ("Brown", 53000, "Sales", datetime.date(2009, 9, 1), 37500),
+                ("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), 56500),
+                ("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 39000),
+                ("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 55000),
+                ("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 37000),
+                ("Moore", 34000, "IT", datetime.date(2013, 8, 1), 50000),
+            ],
+            transform=lambda row: (
+                row.name,
+                row.salary,
+                row.department,
+                row.hire_date,
+                row.avg_salary_cohort,
+            ),
+        )
+
+    @skipUnlessDBFeature("supports_frame_exclusion")
+    def test_row_range_rank_exclude_ties(self):
+        qs = Employee.objects.annotate(
+            sum_salary_cohort=Window(
+                expression=Sum("salary"),
+                order_by=[F("hire_date").asc(), F("name").desc()],
+                frame=RowRange(start=-1, end=1, exclusion=WindowFrameExclusion.TIES),
+            )
+        ).order_by("hire_date")
+        self.assertIn(
+            "ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING EXCLUDE TIES",
+            str(qs.query),
+        )
+        self.assertQuerySetEqual(
+            qs,
+            [
+                ("Miller", 100000, "Management", datetime.date(2005, 6, 1), 180000),
+                ("Johnson", 80000, "Management", datetime.date(2005, 7, 1), 225000),
+                ("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 180000),
+                ("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 145000),
+                ("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 137000),
+                ("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), 135000),
+                ("Brown", 53000, "Sales", datetime.date(2009, 9, 1), 128000),
+                ("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), 151000),
+                ("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 138000),
+                ("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 150000),
+                ("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 124000),
+                ("Moore", 34000, "IT", datetime.date(2013, 8, 1), 84000),
+            ],
+            transform=lambda row: (
+                row.name,
+                row.salary,
+                row.department,
+                row.hire_date,
+                row.sum_salary_cohort,
+            ),
+        )
+
+    @skipUnlessDBFeature("supports_frame_exclusion")
+    def test_row_range_rank_exclude_no_others(self):
+        qs = Employee.objects.annotate(
+            sum_salary_cohort=Window(
+                expression=Sum("salary"),
+                order_by=[F("hire_date").asc(), F("name").desc()],
+                frame=RowRange(
+                    start=-1, end=1, exclusion=WindowFrameExclusion.NO_OTHERS
+                ),
+            )
+        ).order_by("hire_date")
+        self.assertIn(
+            "ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING EXCLUDE NO OTHERS",
+            str(qs.query),
+        )
+        self.assertQuerySetEqual(
+            qs,
+            [
+                ("Miller", 100000, "Management", datetime.date(2005, 6, 1), 180000),
+                ("Johnson", 80000, "Management", datetime.date(2005, 7, 1), 225000),
+                ("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 180000),
+                ("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 145000),
+                ("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 137000),
+                ("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), 135000),
+                ("Brown", 53000, "Sales", datetime.date(2009, 9, 1), 128000),
+                ("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), 151000),
+                ("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 138000),
+                ("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 150000),
+                ("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 124000),
+                ("Moore", 34000, "IT", datetime.date(2013, 8, 1), 84000),
+            ],
+            transform=lambda row: (
+                row.name,
+                row.salary,
+                row.department,
+                row.hire_date,
+                row.sum_salary_cohort,
+            ),
+        )
+
+    @skipIfDBFeature("supports_frame_exclusion")
+    def test_unsupported_frame_exclusion_raises_error(self):
+        msg = "This backend does not support window frame exclusions."
+        with self.assertRaisesMessage(NotSupportedError, msg):
+            list(
+                Employee.objects.annotate(
+                    avg_salary_cohort=Window(
+                        expression=Avg("salary"),
+                        order_by=[F("hire_date").asc(), F("name").desc()],
+                        frame=RowRange(
+                            start=-1, end=1, exclusion=WindowFrameExclusion.CURRENT_ROW
+                        ),
+                    )
+                )
+            )
+
+    @skipUnlessDBFeature("supports_frame_exclusion")
+    def test_invalid_frame_exclusion_value_raises_error(self):
+        msg = "RowRange.exclusion must be a WindowFrameExclusion instance."
+        with self.assertRaisesMessage(TypeError, msg):
+            Employee.objects.annotate(
+                avg_salary_cohort=Window(
+                    expression=Avg("salary"),
+                    order_by=[F("hire_date").asc(), F("name").desc()],
+                    frame=RowRange(start=-1, end=1, exclusion="RUBBISH"),
+                )
+            )
+
     def test_row_range_rank(self):
         """
         A query with ROWS BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING.
@@ -1735,6 +1961,13 @@ class NonQueryWindowTests(SimpleTestCase):
             repr(RowRange(start=1, end=2)),
             "<RowRange: ROWS BETWEEN 1 FOLLOWING AND 2 FOLLOWING>",
         )
+        self.assertEqual(
+            repr(RowRange(start=1, end=2, exclusion=WindowFrameExclusion.CURRENT_ROW)),
+            "<RowRange: ROWS BETWEEN 1 FOLLOWING AND 2 FOLLOWING EXCLUDE CURRENT ROW>",
+        )
+
+    def test_window_frame_exclusion_repr(self):
+        self.assertEqual(repr(WindowFrameExclusion.TIES), "WindowFrameExclusion.TIES")
 
     def test_empty_group_by_cols(self):
         window = Window(expression=Sum("pk"))