浏览代码

Refs #29850 -- Added RowRange support for positive integer start and negative integer end.

Sarah Boyce 1 年之前
父节点
当前提交
6375cee490

+ 36 - 28
django/db/backends/base/operations.py

@@ -714,42 +714,50 @@ class BaseDatabaseOperations:
             "This backend does not support %s subtraction." % internal_type
             "This backend does not support %s subtraction." % internal_type
         )
         )
 
 
-    def window_frame_start(self, start):
-        if isinstance(start, int):
-            if start < 0:
-                return "%d %s" % (abs(start), self.PRECEDING)
-            elif start == 0:
+    def window_frame_value(self, value):
+        if isinstance(value, int):
+            if value == 0:
                 return self.CURRENT_ROW
                 return self.CURRENT_ROW
-        elif start is None:
-            return self.UNBOUNDED_PRECEDING
-        raise ValueError(
-            "start argument must be a negative integer, zero, or None, but got '%s'."
-            % start
-        )
-
-    def window_frame_end(self, end):
-        if isinstance(end, int):
-            if end == 0:
-                return self.CURRENT_ROW
-            elif end > 0:
-                return "%d %s" % (end, self.FOLLOWING)
-        elif end is None:
-            return self.UNBOUNDED_FOLLOWING
-        raise ValueError(
-            "end argument must be a positive integer, zero, or None, but got '%s'."
-            % end
-        )
+            elif value < 0:
+                return "%d %s" % (abs(value), self.PRECEDING)
+            else:
+                return "%d %s" % (value, self.FOLLOWING)
 
 
     def window_frame_rows_start_end(self, start=None, end=None):
     def window_frame_rows_start_end(self, start=None, end=None):
         """
         """
         Return SQL for start and end points in an OVER clause window frame.
         Return SQL for start and end points in an OVER clause window frame.
         """
         """
-        if not self.connection.features.supports_over_clause:
-            raise NotSupportedError("This backend does not support window expressions.")
-        return self.window_frame_start(start), self.window_frame_end(end)
+        if isinstance(start, int) and isinstance(end, int) and start > end:
+            raise ValueError("start cannot be greater than end.")
+        if start is not None and not isinstance(start, int):
+            raise ValueError(
+                f"start argument must be an integer, zero, or None, but got '{start}'."
+            )
+        if end is not None and not isinstance(end, int):
+            raise ValueError(
+                f"end argument must be an integer, zero, or None, but got '{end}'."
+            )
+        start_ = self.window_frame_value(start) or self.UNBOUNDED_PRECEDING
+        end_ = self.window_frame_value(end) or self.UNBOUNDED_FOLLOWING
+        return start_, end_
 
 
     def window_frame_range_start_end(self, start=None, end=None):
     def window_frame_range_start_end(self, start=None, end=None):
-        start_, end_ = self.window_frame_rows_start_end(start, end)
+        if (start is not None and not isinstance(start, int)) or (
+            isinstance(start, int) and start > 0
+        ):
+            raise ValueError(
+                "start argument must be a negative integer, zero, or None, "
+                "but got '%s'." % start
+            )
+        if (end is not None and not isinstance(end, int)) or (
+            isinstance(end, int) and end < 0
+        ):
+            raise ValueError(
+                "end argument must be a positive integer, zero, or None, but got '%s'."
+                % end
+            )
+        start_ = self.window_frame_value(start) or self.UNBOUNDED_PRECEDING
+        end_ = self.window_frame_value(end) or self.UNBOUNDED_FOLLOWING
         features = self.connection.features
         features = self.connection.features
         if features.only_supports_unbounded_with_preceding_and_following and (
         if features.only_supports_unbounded_with_preceding_and_following and (
             (start and start < 0) or (end and end > 0)
             (start and start < 0) or (end and end > 0)

+ 4 - 0
django/db/models/expressions.py

@@ -1895,6 +1895,8 @@ class WindowFrame(Expression):
             start = "%d %s" % (abs(self.start.value), connection.ops.PRECEDING)
             start = "%d %s" % (abs(self.start.value), connection.ops.PRECEDING)
         elif self.start.value is not None and self.start.value == 0:
         elif self.start.value is not None and self.start.value == 0:
             start = connection.ops.CURRENT_ROW
             start = connection.ops.CURRENT_ROW
+        elif self.start.value is not None and self.start.value > 0:
+            start = "%d %s" % (self.start.value, connection.ops.FOLLOWING)
         else:
         else:
             start = connection.ops.UNBOUNDED_PRECEDING
             start = connection.ops.UNBOUNDED_PRECEDING
 
 
@@ -1902,6 +1904,8 @@ class WindowFrame(Expression):
             end = "%d %s" % (self.end.value, connection.ops.FOLLOWING)
             end = "%d %s" % (self.end.value, connection.ops.FOLLOWING)
         elif self.end.value is not None and self.end.value == 0:
         elif self.end.value is not None and self.end.value == 0:
             end = connection.ops.CURRENT_ROW
             end = connection.ops.CURRENT_ROW
+        elif self.end.value is not None and self.end.value < 0:
+            end = "%d %s" % (abs(self.end.value), connection.ops.PRECEDING)
         else:
         else:
             end = connection.ops.UNBOUNDED_FOLLOWING
             end = connection.ops.UNBOUNDED_FOLLOWING
         return self.template % {
         return self.template % {

+ 11 - 3
docs/ref/models/expressions.txt

@@ -923,9 +923,12 @@ SQL generated by the ORM and is by default ``UNBOUNDED FOLLOWING``. The default
 frame includes all rows from the partition to the last row in the set.
 frame includes all rows from the partition to the last row in the set.
 
 
 The accepted values for the ``start`` and ``end`` arguments are ``None``, an
 The accepted values for the ``start`` and ``end`` arguments are ``None``, an
-integer, or zero. A negative integer for ``start`` results in ``N preceding``,
-while ``None`` yields ``UNBOUNDED PRECEDING``. For both ``start`` and ``end``,
-zero will return ``CURRENT ROW``. Positive integers are accepted for ``end``.
+integer, or zero. A negative integer for ``start`` results in ``N PRECEDING``,
+while ``None`` yields ``UNBOUNDED PRECEDING``. In ``ROWS`` mode, a positive
+integer can be used for ```start`` resulting in ``N FOLLOWING``. Positive
+integers are accepted for ``end`` and results in ``N FOLLOWING``. In ``ROWS``
+mode, a negative integer can be used for ```end`` resulting in ``N PRECEDING``.
+For both ``start`` and ``end``, zero will return ``CURRENT ROW``.
 
 
 There's a difference in what ``CURRENT ROW`` includes. When specified in
 There's a difference in what ``CURRENT ROW`` includes. When specified in
 ``ROWS`` mode, the frame starts or ends with the current row. When specified in
 ``ROWS`` mode, the frame starts or ends with the current row. When specified in
@@ -970,6 +973,11 @@ released between twelve months before and twelve months after each movie:
     ...     ),
     ...     ),
     ... )
     ... )
 
 
+.. versionchanged:: 5.1
+
+    Support for positive integer ``start`` and negative integer ``end`` was
+    added for ``RowRange``.
+
 .. currentmodule:: django.db.models
 .. currentmodule:: django.db.models
 
 
 Technical Information
 Technical Information

+ 3 - 0
docs/releases/5.1.txt

@@ -171,6 +171,9 @@ Models
 * :meth:`.QuerySet.explain` now supports the ``generic_plan`` option on
 * :meth:`.QuerySet.explain` now supports the ``generic_plan`` option on
   PostgreSQL 16+.
   PostgreSQL 16+.
 
 
+* :class:`~django.db.models.expressions.RowRange` now accepts positive integers
+  for the ``start`` argument and negative integers for the ``end`` argument.
+
 Requests and Responses
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 ~~~~~~~~~~~~~~~~~~~~~~
 
 

+ 101 - 2
tests/expressions_window/tests.py

@@ -1328,6 +1328,84 @@ class WindowFunctionTests(TestCase):
             ),
             ),
         )
         )
 
 
+    def test_row_range_both_preceding(self):
+        """
+        A query with ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING.
+        The resulting sum is the sum of the previous two (if they exist) rows
+        according to the ordering clause.
+        """
+        qs = Employee.objects.annotate(
+            sum=Window(
+                expression=Sum("salary"),
+                order_by=[F("hire_date").asc(), F("name").desc()],
+                frame=RowRange(start=-2, end=-1),
+            )
+        ).order_by("hire_date")
+        self.assertIn("ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING", str(qs.query))
+        self.assertQuerySetEqual(
+            qs,
+            [
+                ("Miller", 100000, "Management", datetime.date(2005, 6, 1), None),
+                ("Johnson", 80000, "Management", datetime.date(2005, 7, 1), 100000),
+                ("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 180000),
+                ("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 125000),
+                ("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 100000),
+                ("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), 100000),
+                ("Brown", 53000, "Sales", datetime.date(2009, 9, 1), 82000),
+                ("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), 90000),
+                ("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 91000),
+                ("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 98000),
+                ("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 100000),
+                ("Moore", 34000, "IT", datetime.date(2013, 8, 1), 90000),
+            ],
+            transform=lambda row: (
+                row.name,
+                row.salary,
+                row.department,
+                row.hire_date,
+                row.sum,
+            ),
+        )
+
+    def test_row_range_both_following(self):
+        """
+        A query with ROWS BETWEEN 1 FOLLOWING AND 2 FOLLOWING.
+        The resulting sum is the sum of the following two (if they exist) rows
+        according to the ordering clause.
+        """
+        qs = Employee.objects.annotate(
+            sum=Window(
+                expression=Sum("salary"),
+                order_by=[F("hire_date").asc(), F("name").desc()],
+                frame=RowRange(start=1, end=2),
+            )
+        ).order_by("hire_date")
+        self.assertIn("ROWS BETWEEN 1 FOLLOWING AND 2 FOLLOWING", str(qs.query))
+        self.assertQuerySetEqual(
+            qs,
+            [
+                ("Miller", 100000, "Management", datetime.date(2005, 6, 1), 125000),
+                ("Johnson", 80000, "Management", datetime.date(2005, 7, 1), 100000),
+                ("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 100000),
+                ("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 82000),
+                ("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 90000),
+                ("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), 91000),
+                ("Brown", 53000, "Sales", datetime.date(2009, 9, 1), 98000),
+                ("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), 100000),
+                ("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 90000),
+                ("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 84000),
+                ("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 34000),
+                ("Moore", 34000, "IT", datetime.date(2013, 8, 1), None),
+            ],
+            transform=lambda row: (
+                row.name,
+                row.salary,
+                row.department,
+                row.hire_date,
+                row.sum,
+            ),
+        )
+
     @skipUnlessDBFeature("can_distinct_on_fields")
     @skipUnlessDBFeature("can_distinct_on_fields")
     def test_distinct_window_function(self):
     def test_distinct_window_function(self):
         """
         """
@@ -1479,6 +1557,19 @@ class WindowFunctionTests(TestCase):
                 )
                 )
             )
             )
 
 
+    def test_invalid_start_end_value_for_row_range(self):
+        msg = "start cannot be greater than end."
+        with self.assertRaisesMessage(ValueError, msg):
+            list(
+                Employee.objects.annotate(
+                    test=Window(
+                        expression=Sum("salary"),
+                        order_by=F("hire_date").asc(),
+                        frame=RowRange(start=4, end=-3),
+                    )
+                )
+            )
+
     def test_invalid_type_end_value_range(self):
     def test_invalid_type_end_value_range(self):
         msg = "end argument must be a positive integer, zero, or None, but got 'a'."
         msg = "end argument must be a positive integer, zero, or None, but got 'a'."
         with self.assertRaisesMessage(ValueError, msg):
         with self.assertRaisesMessage(ValueError, msg):
@@ -1505,7 +1596,7 @@ class WindowFunctionTests(TestCase):
             )
             )
 
 
     def test_invalid_type_end_row_range(self):
     def test_invalid_type_end_row_range(self):
-        msg = "end argument must be a positive integer, zero, or None, but got 'a'."
+        msg = "end argument must be an integer, zero, or None, but got 'a'."
         with self.assertRaisesMessage(ValueError, msg):
         with self.assertRaisesMessage(ValueError, msg):
             list(
             list(
                 Employee.objects.annotate(
                 Employee.objects.annotate(
@@ -1551,7 +1642,7 @@ class WindowFunctionTests(TestCase):
             )
             )
 
 
     def test_invalid_type_start_row_range(self):
     def test_invalid_type_start_row_range(self):
-        msg = "start argument must be a negative integer, zero, or None, but got 'a'."
+        msg = "start argument must be an integer, zero, or None, but got 'a'."
         with self.assertRaisesMessage(ValueError, msg):
         with self.assertRaisesMessage(ValueError, msg):
             list(
             list(
                 Employee.objects.annotate(
                 Employee.objects.annotate(
@@ -1636,6 +1727,14 @@ class NonQueryWindowTests(SimpleTestCase):
             repr(RowRange(start=0, end=0)),
             repr(RowRange(start=0, end=0)),
             "<RowRange: ROWS BETWEEN CURRENT ROW AND CURRENT ROW>",
             "<RowRange: ROWS BETWEEN CURRENT ROW AND CURRENT ROW>",
         )
         )
+        self.assertEqual(
+            repr(RowRange(start=-2, end=-1)),
+            "<RowRange: ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING>",
+        )
+        self.assertEqual(
+            repr(RowRange(start=1, end=2)),
+            "<RowRange: ROWS BETWEEN 1 FOLLOWING AND 2 FOLLOWING>",
+        )
 
 
     def test_empty_group_by_cols(self):
     def test_empty_group_by_cols(self):
         window = Window(expression=Sum("pk"))
         window = Window(expression=Sum("pk"))