Pārlūkot izejas kodu

Refs #29850 -- Added RowRange support for positive integer start and negative integer end.

Sarah Boyce 1 gadu atpakaļ
vecāks
revīzija
6375cee490

+ 36 - 28
django/db/backends/base/operations.py

@@ -714,42 +714,50 @@ class BaseDatabaseOperations:
             "This backend does not support %s subtraction." % internal_type
         )
 
-    def window_frame_start(self, start):
-        if isinstance(start, int):
-            if start < 0:
-                return "%d %s" % (abs(start), self.PRECEDING)
-            elif start == 0:
+    def window_frame_value(self, value):
+        if isinstance(value, int):
+            if value == 0:
                 return self.CURRENT_ROW
-        elif start is None:
-            return self.UNBOUNDED_PRECEDING
-        raise ValueError(
-            "start argument must be a negative integer, zero, or None, but got '%s'."
-            % start
-        )
-
-    def window_frame_end(self, end):
-        if isinstance(end, int):
-            if end == 0:
-                return self.CURRENT_ROW
-            elif end > 0:
-                return "%d %s" % (end, self.FOLLOWING)
-        elif end is None:
-            return self.UNBOUNDED_FOLLOWING
-        raise ValueError(
-            "end argument must be a positive integer, zero, or None, but got '%s'."
-            % end
-        )
+            elif value < 0:
+                return "%d %s" % (abs(value), self.PRECEDING)
+            else:
+                return "%d %s" % (value, self.FOLLOWING)
 
     def window_frame_rows_start_end(self, start=None, end=None):
         """
         Return SQL for start and end points in an OVER clause window frame.
         """
-        if not self.connection.features.supports_over_clause:
-            raise NotSupportedError("This backend does not support window expressions.")
-        return self.window_frame_start(start), self.window_frame_end(end)
+        if isinstance(start, int) and isinstance(end, int) and start > end:
+            raise ValueError("start cannot be greater than end.")
+        if start is not None and not isinstance(start, int):
+            raise ValueError(
+                f"start argument must be an integer, zero, or None, but got '{start}'."
+            )
+        if end is not None and not isinstance(end, int):
+            raise ValueError(
+                f"end argument must be an integer, zero, or None, but got '{end}'."
+            )
+        start_ = self.window_frame_value(start) or self.UNBOUNDED_PRECEDING
+        end_ = self.window_frame_value(end) or self.UNBOUNDED_FOLLOWING
+        return start_, end_
 
     def window_frame_range_start_end(self, start=None, end=None):
-        start_, end_ = self.window_frame_rows_start_end(start, end)
+        if (start is not None and not isinstance(start, int)) or (
+            isinstance(start, int) and start > 0
+        ):
+            raise ValueError(
+                "start argument must be a negative integer, zero, or None, "
+                "but got '%s'." % start
+            )
+        if (end is not None and not isinstance(end, int)) or (
+            isinstance(end, int) and end < 0
+        ):
+            raise ValueError(
+                "end argument must be a positive integer, zero, or None, but got '%s'."
+                % end
+            )
+        start_ = self.window_frame_value(start) or self.UNBOUNDED_PRECEDING
+        end_ = self.window_frame_value(end) or self.UNBOUNDED_FOLLOWING
         features = self.connection.features
         if features.only_supports_unbounded_with_preceding_and_following and (
             (start and start < 0) or (end and end > 0)

+ 4 - 0
django/db/models/expressions.py

@@ -1895,6 +1895,8 @@ class WindowFrame(Expression):
             start = "%d %s" % (abs(self.start.value), connection.ops.PRECEDING)
         elif self.start.value is not None and self.start.value == 0:
             start = connection.ops.CURRENT_ROW
+        elif self.start.value is not None and self.start.value > 0:
+            start = "%d %s" % (self.start.value, connection.ops.FOLLOWING)
         else:
             start = connection.ops.UNBOUNDED_PRECEDING
 
@@ -1902,6 +1904,8 @@ class WindowFrame(Expression):
             end = "%d %s" % (self.end.value, connection.ops.FOLLOWING)
         elif self.end.value is not None and self.end.value == 0:
             end = connection.ops.CURRENT_ROW
+        elif self.end.value is not None and self.end.value < 0:
+            end = "%d %s" % (abs(self.end.value), connection.ops.PRECEDING)
         else:
             end = connection.ops.UNBOUNDED_FOLLOWING
         return self.template % {

+ 11 - 3
docs/ref/models/expressions.txt

@@ -923,9 +923,12 @@ SQL generated by the ORM and is by default ``UNBOUNDED FOLLOWING``. The default
 frame includes all rows from the partition to the last row in the set.
 
 The accepted values for the ``start`` and ``end`` arguments are ``None``, an
-integer, or zero. A negative integer for ``start`` results in ``N preceding``,
-while ``None`` yields ``UNBOUNDED PRECEDING``. For both ``start`` and ``end``,
-zero will return ``CURRENT ROW``. Positive integers are accepted for ``end``.
+integer, or zero. A negative integer for ``start`` results in ``N PRECEDING``,
+while ``None`` yields ``UNBOUNDED PRECEDING``. In ``ROWS`` mode, a positive
+integer can be used for ```start`` resulting in ``N FOLLOWING``. Positive
+integers are accepted for ``end`` and results in ``N FOLLOWING``. In ``ROWS``
+mode, a negative integer can be used for ```end`` resulting in ``N PRECEDING``.
+For both ``start`` and ``end``, zero will return ``CURRENT ROW``.
 
 There's a difference in what ``CURRENT ROW`` includes. When specified in
 ``ROWS`` mode, the frame starts or ends with the current row. When specified in
@@ -970,6 +973,11 @@ released between twelve months before and twelve months after each movie:
     ...     ),
     ... )
 
+.. versionchanged:: 5.1
+
+    Support for positive integer ``start`` and negative integer ``end`` was
+    added for ``RowRange``.
+
 .. currentmodule:: django.db.models
 
 Technical Information

+ 3 - 0
docs/releases/5.1.txt

@@ -171,6 +171,9 @@ Models
 * :meth:`.QuerySet.explain` now supports the ``generic_plan`` option on
   PostgreSQL 16+.
 
+* :class:`~django.db.models.expressions.RowRange` now accepts positive integers
+  for the ``start`` argument and negative integers for the ``end`` argument.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 

+ 101 - 2
tests/expressions_window/tests.py

@@ -1328,6 +1328,84 @@ class WindowFunctionTests(TestCase):
             ),
         )
 
+    def test_row_range_both_preceding(self):
+        """
+        A query with ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING.
+        The resulting sum is the sum of the previous two (if they exist) rows
+        according to the ordering clause.
+        """
+        qs = Employee.objects.annotate(
+            sum=Window(
+                expression=Sum("salary"),
+                order_by=[F("hire_date").asc(), F("name").desc()],
+                frame=RowRange(start=-2, end=-1),
+            )
+        ).order_by("hire_date")
+        self.assertIn("ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING", str(qs.query))
+        self.assertQuerySetEqual(
+            qs,
+            [
+                ("Miller", 100000, "Management", datetime.date(2005, 6, 1), None),
+                ("Johnson", 80000, "Management", datetime.date(2005, 7, 1), 100000),
+                ("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 180000),
+                ("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 125000),
+                ("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 100000),
+                ("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), 100000),
+                ("Brown", 53000, "Sales", datetime.date(2009, 9, 1), 82000),
+                ("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), 90000),
+                ("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 91000),
+                ("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 98000),
+                ("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 100000),
+                ("Moore", 34000, "IT", datetime.date(2013, 8, 1), 90000),
+            ],
+            transform=lambda row: (
+                row.name,
+                row.salary,
+                row.department,
+                row.hire_date,
+                row.sum,
+            ),
+        )
+
+    def test_row_range_both_following(self):
+        """
+        A query with ROWS BETWEEN 1 FOLLOWING AND 2 FOLLOWING.
+        The resulting sum is the sum of the following two (if they exist) rows
+        according to the ordering clause.
+        """
+        qs = Employee.objects.annotate(
+            sum=Window(
+                expression=Sum("salary"),
+                order_by=[F("hire_date").asc(), F("name").desc()],
+                frame=RowRange(start=1, end=2),
+            )
+        ).order_by("hire_date")
+        self.assertIn("ROWS BETWEEN 1 FOLLOWING AND 2 FOLLOWING", str(qs.query))
+        self.assertQuerySetEqual(
+            qs,
+            [
+                ("Miller", 100000, "Management", datetime.date(2005, 6, 1), 125000),
+                ("Johnson", 80000, "Management", datetime.date(2005, 7, 1), 100000),
+                ("Jones", 45000, "Accounting", datetime.date(2005, 11, 1), 100000),
+                ("Smith", 55000, "Sales", datetime.date(2007, 6, 1), 82000),
+                ("Jenson", 45000, "Accounting", datetime.date(2008, 4, 1), 90000),
+                ("Williams", 37000, "Accounting", datetime.date(2009, 6, 1), 91000),
+                ("Brown", 53000, "Sales", datetime.date(2009, 9, 1), 98000),
+                ("Smith", 38000, "Marketing", datetime.date(2009, 10, 1), 100000),
+                ("Wilkinson", 60000, "IT", datetime.date(2011, 3, 1), 90000),
+                ("Johnson", 40000, "Marketing", datetime.date(2012, 3, 1), 84000),
+                ("Adams", 50000, "Accounting", datetime.date(2013, 7, 1), 34000),
+                ("Moore", 34000, "IT", datetime.date(2013, 8, 1), None),
+            ],
+            transform=lambda row: (
+                row.name,
+                row.salary,
+                row.department,
+                row.hire_date,
+                row.sum,
+            ),
+        )
+
     @skipUnlessDBFeature("can_distinct_on_fields")
     def test_distinct_window_function(self):
         """
@@ -1479,6 +1557,19 @@ class WindowFunctionTests(TestCase):
                 )
             )
 
+    def test_invalid_start_end_value_for_row_range(self):
+        msg = "start cannot be greater than end."
+        with self.assertRaisesMessage(ValueError, msg):
+            list(
+                Employee.objects.annotate(
+                    test=Window(
+                        expression=Sum("salary"),
+                        order_by=F("hire_date").asc(),
+                        frame=RowRange(start=4, end=-3),
+                    )
+                )
+            )
+
     def test_invalid_type_end_value_range(self):
         msg = "end argument must be a positive integer, zero, or None, but got 'a'."
         with self.assertRaisesMessage(ValueError, msg):
@@ -1505,7 +1596,7 @@ class WindowFunctionTests(TestCase):
             )
 
     def test_invalid_type_end_row_range(self):
-        msg = "end argument must be a positive integer, zero, or None, but got 'a'."
+        msg = "end argument must be an integer, zero, or None, but got 'a'."
         with self.assertRaisesMessage(ValueError, msg):
             list(
                 Employee.objects.annotate(
@@ -1551,7 +1642,7 @@ class WindowFunctionTests(TestCase):
             )
 
     def test_invalid_type_start_row_range(self):
-        msg = "start argument must be a negative integer, zero, or None, but got 'a'."
+        msg = "start argument must be an integer, zero, or None, but got 'a'."
         with self.assertRaisesMessage(ValueError, msg):
             list(
                 Employee.objects.annotate(
@@ -1636,6 +1727,14 @@ class NonQueryWindowTests(SimpleTestCase):
             repr(RowRange(start=0, end=0)),
             "<RowRange: ROWS BETWEEN CURRENT ROW AND CURRENT ROW>",
         )
+        self.assertEqual(
+            repr(RowRange(start=-2, end=-1)),
+            "<RowRange: ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING>",
+        )
+        self.assertEqual(
+            repr(RowRange(start=1, end=2)),
+            "<RowRange: ROWS BETWEEN 1 FOLLOWING AND 2 FOLLOWING>",
+        )
 
     def test_empty_group_by_cols(self):
         window = Window(expression=Sum("pk"))