Browse Source

Fixed #27473 -- Added DurationField support to Extract.

Daniel Hahler 8 years ago
parent
commit
43a4835edf

+ 11 - 3
django/db/models/functions/datetime.py

@@ -2,7 +2,8 @@ from datetime import datetime
 
 from django.conf import settings
 from django.db.models import (
-    DateField, DateTimeField, IntegerField, TimeField, Transform,
+    DateField, DateTimeField, DurationField, IntegerField, TimeField,
+    Transform,
 )
 from django.db.models.lookups import (
     YearExact, YearGt, YearGte, YearLt, YearLte,
@@ -49,6 +50,10 @@ class Extract(TimezoneMixin, Transform):
             sql = connection.ops.date_extract_sql(self.lookup_name, sql)
         elif isinstance(lhs_output_field, TimeField):
             sql = connection.ops.time_extract_sql(self.lookup_name, sql)
+        elif isinstance(lhs_output_field, DurationField):
+            if not connection.features.has_native_duration_field:
+                raise ValueError('Extract requires native DurationField database support.')
+            sql = connection.ops.time_extract_sql(self.lookup_name, sql)
         else:
             # resolve_expression has already validated the output_field so this
             # assert should never be hit.
@@ -58,8 +63,11 @@ class Extract(TimezoneMixin, Transform):
     def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
         copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
         field = copy.lhs.output_field
-        if not isinstance(field, (DateField, DateTimeField, TimeField)):
-            raise ValueError('Extract input expression must be DateField, DateTimeField, or TimeField.')
+        if not isinstance(field, (DateField, DateTimeField, TimeField, DurationField)):
+            raise ValueError(
+                'Extract input expression must be DateField, DateTimeField, '
+                'TimeField, or DurationField.'
+            )
         # Passing dates to functions expecting datetimes is most likely a mistake.
         if type(field) == DateField and copy.lookup_name in ('hour', 'minute', 'second'):
             raise ValueError(

+ 10 - 6
docs/ref/models/database-functions.txt

@@ -331,12 +331,16 @@ We'll be using the following model in examples of each function::
 
 Extracts a component of a date as a number.
 
-Takes an ``expression`` representing a ``DateField`` or ``DateTimeField`` and a
-``lookup_name``, and returns the part of the date referenced by ``lookup_name``
-as an ``IntegerField``. Django usually uses the databases' extract function, so
-you may use any ``lookup_name`` that your database supports. A ``tzinfo``
-subclass, usually provided by ``pytz``, can be passed to extract a value in a
-specific timezone.
+Takes an ``expression`` representing a ``DateField``, ``DateTimeField``,
+``TimeField``, or ``DurationField`` and a ``lookup_name``, and returns the part
+of the date referenced by ``lookup_name`` as an ``IntegerField``.
+Django usually uses the databases' extract function, so you may use any
+``lookup_name`` that your database supports. A ``tzinfo`` subclass, usually
+provided by ``pytz``, can be passed to extract a value in a specific timezone.
+
+.. versionchanged:: 2.0
+
+    Support for ``DurationField`` was added.
 
 Given the datetime ``2015-06-15 23:30:01.000321+00:00``, the built-in
 ``lookup_name``\s return:

+ 4 - 0
docs/releases/2.0.txt

@@ -248,6 +248,10 @@ Models
 * Added the :attr:`~django.db.models.Index.db_tablespace` parameter to
   class-based indexes.
 
+* If the database supports a native duration field (Oracle and PostgreSQL),
+  :class:`~django.db.models.functions.datetime.Extract` now works with
+  :class:`~django.db.models.DurationField`.
+
 Requests and Responses
 ~~~~~~~~~~~~~~~~~~~~~~
 

+ 34 - 2
tests/db_functions/test_datetime.py

@@ -11,7 +11,9 @@ from django.db.models.functions import (
     Trunc, TruncDate, TruncDay, TruncHour, TruncMinute, TruncMonth,
     TruncQuarter, TruncSecond, TruncTime, TruncYear,
 )
-from django.test import TestCase, override_settings
+from django.test import (
+    TestCase, override_settings, skipIfDBFeature, skipUnlessDBFeature,
+)
 from django.utils import timezone
 
 from .models import DTModel
@@ -147,7 +149,7 @@ class DateFunctionTests(TestCase):
         with self.assertRaisesMessage(ValueError, 'lookup_name must be provided'):
             Extract('start_datetime')
 
-        msg = 'Extract input expression must be DateField, DateTimeField, or TimeField.'
+        msg = 'Extract input expression must be DateField, DateTimeField, TimeField, or DurationField.'
         with self.assertRaisesMessage(ValueError, msg):
             list(DTModel.objects.annotate(extracted=Extract('name', 'hour')))
 
@@ -208,6 +210,36 @@ class DateFunctionTests(TestCase):
         self.assertEqual(DTModel.objects.filter(start_date__month=Extract('start_date', 'month')).count(), 2)
         self.assertEqual(DTModel.objects.filter(start_time__hour=Extract('start_time', 'hour')).count(), 2)
 
+    @skipUnlessDBFeature('has_native_duration_field')
+    def test_extract_duration(self):
+        start_datetime = microsecond_support(datetime(2015, 6, 15, 14, 30, 50, 321))
+        end_datetime = microsecond_support(datetime(2016, 6, 15, 14, 10, 50, 123))
+        if settings.USE_TZ:
+            start_datetime = timezone.make_aware(start_datetime, is_dst=False)
+            end_datetime = timezone.make_aware(end_datetime, is_dst=False)
+        self.create_model(start_datetime, end_datetime)
+        self.create_model(end_datetime, start_datetime)
+        self.assertQuerysetEqual(
+            DTModel.objects.annotate(extracted=Extract('duration', 'second')).order_by('start_datetime'),
+            [
+                (start_datetime, (end_datetime - start_datetime).seconds % 60),
+                (end_datetime, (start_datetime - end_datetime).seconds % 60)
+            ],
+            lambda m: (m.start_datetime, m.extracted)
+        )
+        self.assertEqual(
+            DTModel.objects.annotate(
+                duration_days=Extract('duration', 'day'),
+            ).filter(duration_days__gt=200).count(),
+            1
+        )
+
+    @skipIfDBFeature('has_native_duration_field')
+    def test_extract_duration_without_native_duration_field(self):
+        msg = 'Extract requires native DurationField database support.'
+        with self.assertRaisesMessage(ValueError, msg):
+            list(DTModel.objects.annotate(extracted=Extract('duration', 'second')))
+
     def test_extract_year_func(self):
         start_datetime = microsecond_support(datetime(2015, 6, 15, 14, 30, 50, 321))
         end_datetime = microsecond_support(datetime(2016, 6, 15, 14, 10, 50, 123))