Sfoglia il codice sorgente

Fixed #29724 -- Fixed timezone handling in ModelAdmin.date_hierarchy queries.

Thanks Alexander Holmbäck for the initial patch.
Hasan Ramezani 5 anni fa
parent
commit
55cdf6c52d

+ 18 - 5
django/contrib/admin/templatetags/admin_list.py

@@ -2,7 +2,8 @@ import datetime
 
 from django.contrib.admin.templatetags.admin_urls import add_preserved_filters
 from django.contrib.admin.utils import (
-    display_for_field, display_for_value, label_for_field, lookup_field,
+    display_for_field, display_for_value, get_fields_from_path,
+    label_for_field, lookup_field,
 )
 from django.contrib.admin.views.main import (
     ALL_VAR, ORDER_VAR, PAGE_VAR, SEARCH_VAR,
@@ -13,7 +14,7 @@ from django.template import Library
 from django.template.loader import get_template
 from django.templatetags.static import static
 from django.urls import NoReverseMatch
-from django.utils import formats
+from django.utils import formats, timezone
 from django.utils.html import format_html
 from django.utils.safestring import mark_safe
 from django.utils.text import capfirst
@@ -359,6 +360,13 @@ def date_hierarchy(cl):
     """
     if cl.date_hierarchy:
         field_name = cl.date_hierarchy
+        field = get_fields_from_path(cl.model, field_name)[-1]
+        if isinstance(field, models.DateTimeField):
+            dates_or_datetimes = 'datetimes'
+            qs_kwargs = {'is_dst': True}
+        else:
+            dates_or_datetimes = 'dates'
+            qs_kwargs = {}
         year_field = '%s__year' % field_name
         month_field = '%s__month' % field_name
         day_field = '%s__day' % field_name
@@ -374,6 +382,11 @@ def date_hierarchy(cl):
             # select appropriate start level
             date_range = cl.queryset.aggregate(first=models.Min(field_name),
                                                last=models.Max(field_name))
+            if dates_or_datetimes == 'datetimes':
+                date_range = {
+                    k: timezone.localtime(v) if timezone.is_aware(v) else v
+                    for k, v in date_range.items()
+                }
             if date_range['first'] and date_range['last']:
                 if date_range['first'].year == date_range['last'].year:
                     year_lookup = date_range['first'].year
@@ -391,7 +404,7 @@ def date_hierarchy(cl):
                 'choices': [{'title': capfirst(formats.date_format(day, 'MONTH_DAY_FORMAT'))}]
             }
         elif year_lookup and month_lookup:
-            days = getattr(cl.queryset, 'dates')(field_name, 'day')
+            days = getattr(cl.queryset, dates_or_datetimes)(field_name, 'day', **qs_kwargs)
             return {
                 'show': True,
                 'back': {
@@ -404,7 +417,7 @@ def date_hierarchy(cl):
                 } for day in days]
             }
         elif year_lookup:
-            months = getattr(cl.queryset, 'dates')(field_name, 'month')
+            months = getattr(cl.queryset, dates_or_datetimes)(field_name, 'month', **qs_kwargs)
             return {
                 'show': True,
                 'back': {
@@ -417,7 +430,7 @@ def date_hierarchy(cl):
                 } for month in months]
             }
         else:
-            years = getattr(cl.queryset, 'dates')(field_name, 'year')
+            years = getattr(cl.queryset, dates_or_datetimes)(field_name, 'year', **qs_kwargs)
             return {
                 'show': True,
                 'back': None,

+ 12 - 0
tests/admin_views/tests.py

@@ -981,6 +981,18 @@ class AdminViewBasicTest(AdminViewBasicTestCase):
         self.assertContains(response, 'question__expires__month=10')
         self.assertContains(response, 'question__expires__year=2016')
 
+    @override_settings(TIME_ZONE='America/Los_Angeles', USE_TZ=True)
+    def test_date_hierarchy_local_date_differ_from_utc(self):
+        # This datetime is 2017-01-01 in UTC.
+        date = pytz.timezone('America/Los_Angeles').localize(datetime.datetime(2016, 12, 31, 16))
+        q = Question.objects.create(question='Why?', expires=date)
+        Answer2.objects.create(question=q, answer='Because.')
+        response = self.client.get(reverse('admin:admin_views_answer2_changelist'))
+        self.assertEqual(response.status_code, 200)
+        self.assertContains(response, 'question__expires__day=31')
+        self.assertContains(response, 'question__expires__month=12')
+        self.assertContains(response, 'question__expires__year=2016')
+
     def test_sortable_by_columns_subset(self):
         expected_sortable_fields = ('date', 'callable_year')
         expected_not_sortable_fields = (