Bläddra i källkod

Fixed #32358 -- Fixed queryset crash when grouping by annotation with Distance()/Area().

Made MeasureBase hashable.
Illia Volochii 4 år sedan
förälder
incheckning
bef6f75842
3 ändrade filer med 40 tillägg och 1 borttagningar
  1. 3 0
      django/contrib/gis/measure.py
  2. 21 1
      tests/gis_tests/distapp/tests.py
  3. 16 0
      tests/gis_tests/test_measure.py

+ 3 - 0
django/contrib/gis/measure.py

@@ -89,6 +89,9 @@ class MeasureBase:
         else:
             return NotImplemented
 
+    def __hash__(self):
+        return hash(self.standard)
+
     def __lt__(self, other):
         if isinstance(other, self.__class__):
             return self.standard < other.standard

+ 21 - 1
tests/gis_tests/distapp/tests.py

@@ -4,7 +4,9 @@ from django.contrib.gis.db.models.functions import (
 from django.contrib.gis.geos import GEOSGeometry, LineString, Point
 from django.contrib.gis.measure import D  # alias for Distance
 from django.db import NotSupportedError, connection
-from django.db.models import Exists, F, OuterRef, Q
+from django.db.models import (
+    Case, Count, Exists, F, IntegerField, OuterRef, Q, Value, When,
+)
 from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
 
 from ..utils import FuncTestMixin
@@ -214,6 +216,24 @@ class DistanceTest(TestCase):
             SouthTexasCity.objects.count(),
         )
 
+    @skipUnlessDBFeature('supports_distances_lookups')
+    def test_distance_annotation_group_by(self):
+        stx_pnt = self.stx_pnt.transform(
+            SouthTexasCity._meta.get_field('point').srid,
+            clone=True,
+        )
+        qs = SouthTexasCity.objects.annotate(
+            relative_distance=Case(
+                When(point__distance_lte=(stx_pnt, D(km=20)), then=Value(20)),
+                default=Value(100),
+                output_field=IntegerField(),
+            ),
+        ).values('relative_distance').annotate(count=Count('pk'))
+        self.assertCountEqual(qs, [
+            {'relative_distance': 20, 'count': 5},
+            {'relative_distance': 100, 'count': 4},
+        ])
+
     def test_mysql_geodetic_distance_error(self):
         if not connection.ops.mysql:
             self.skipTest('This is a MySQL-specific test.')

+ 16 - 0
tests/gis_tests/test_measure.py

@@ -151,6 +151,14 @@ class DistanceTest(unittest.TestCase):
             with self.subTest(nm=nm):
                 self.assertEqual(att, D.unit_attname(nm))
 
+    def test_hash(self):
+        d1 = D(m=99)
+        d2 = D(m=100)
+        d3 = D(km=0.1)
+        self.assertEqual(hash(d2), hash(d3))
+        self.assertNotEqual(hash(d1), hash(d2))
+        self.assertNotEqual(hash(d1), hash(d3))
+
 
 class AreaTest(unittest.TestCase):
     "Testing the Area object"
@@ -272,6 +280,14 @@ class AreaTest(unittest.TestCase):
         self.assertEqual(repr(a1), 'Area(sq_m=100.0)')
         self.assertEqual(repr(a2), 'Area(sq_km=3.5)')
 
+    def test_hash(self):
+        a1 = A(sq_m=100)
+        a2 = A(sq_m=1000000)
+        a3 = A(sq_km=1)
+        self.assertEqual(hash(a2), hash(a3))
+        self.assertNotEqual(hash(a1), hash(a2))
+        self.assertNotEqual(hash(a1), hash(a3))
+
 
 def suite():
     s = unittest.TestSuite()