浏览代码

Fixed #28432 -- Allowed geometry expressions to be used with distance lookups.

Distance lookups use the Distance function for decreased code redundancy.
Sergey Fedoseev 7 年之前
父节点
当前提交
38af496b98

+ 5 - 0
django/contrib/gis/db/backends/base/operations.py

@@ -1,3 +1,6 @@
+from django.contrib.gis.db.models.functions import Distance
+
+
 class BaseSpatialOperations:
     # Quick booleans for the type of this spatial backend, and
     # an attribute for the spatial database version tuple (if applicable)
@@ -113,3 +116,5 @@ class BaseSpatialOperations:
 
     def spatial_ref_sys(self):
         raise NotImplementedError('subclasses of BaseSpatialOperations must a provide spatial_ref_sys() method')
+
+    distance_expr_for_lookup = staticmethod(Distance)

+ 0 - 8
django/contrib/gis/db/backends/oracle/operations.py

@@ -26,10 +26,6 @@ class SDOOperator(SpatialOperator):
     sql_template = "%(func)s(%(lhs)s, %(rhs)s) = 'TRUE'"
 
 
-class SDODistance(SpatialOperator):
-    sql_template = "SDO_GEOM.SDO_DISTANCE(%%(lhs)s, %%(rhs)s, %s) %%(op)s %%(value)s" % DEFAULT_TOLERANCE
-
-
 class SDODWithin(SpatialOperator):
     sql_template = "SDO_WITHIN_DISTANCE(%(lhs)s, %(rhs)s, %%s) = 'TRUE'"
 
@@ -104,10 +100,6 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
         'relate': SDORelate(),  # Oracle uses a different syntax, e.g., 'mask=inside+touch'
         'touches': SDOOperator(func='SDO_TOUCH'),
         'within': SDOOperator(func='SDO_INSIDE'),
-        'distance_gt': SDODistance(op='>'),
-        'distance_gte': SDODistance(op='>='),
-        'distance_lt': SDODistance(op='<'),
-        'distance_lte': SDODistance(op='<='),
         'dwithin': SDODWithin(),
     }
 

+ 30 - 24
django/contrib/gis/db/backends/postgis/operations.py

@@ -5,10 +5,12 @@ from django.contrib.gis.db.backends.base.operations import (
     BaseSpatialOperations,
 )
 from django.contrib.gis.db.backends.utils import SpatialOperator
+from django.contrib.gis.db.models import GeometryField, RasterField
 from django.contrib.gis.gdal import GDALRaster
 from django.contrib.gis.measure import Distance
 from django.core.exceptions import ImproperlyConfigured
 from django.db.backends.postgresql.operations import DatabaseOperations
+from django.db.models import Func, Value
 from django.db.utils import ProgrammingError
 from django.utils.functional import cached_property
 from django.utils.version import get_version_tuple
@@ -77,26 +79,18 @@ class PostGISOperator(SpatialOperator):
         return template_params
 
 
-class PostGISDistanceOperator(PostGISOperator):
-    sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s'
-
-    def as_sql(self, connection, lookup, template_params, sql_params):
-        if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection):
-            template_params = self.check_raster(lookup, template_params)
-            sql_template = self.sql_template
-            if len(lookup.rhs_params) == 2 and lookup.rhs_params[-1] == 'spheroid':
-                template_params.update({
-                    'op': self.op,
-                    'func': connection.ops.spatial_function_name('DistanceSpheroid'),
-                })
-                sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s) %(op)s %(value)s'
-                # Using DistanceSpheroid requires the spheroid of the field as
-                # a parameter.
-                sql_params.insert(1, lookup.lhs.output_field.spheroid(connection))
-            else:
-                template_params.update({'op': self.op, 'func': connection.ops.spatial_function_name('DistanceSphere')})
-            return sql_template % template_params, sql_params
-        return super().as_sql(connection, lookup, template_params, sql_params)
+class ST_Polygon(Func):
+    function = 'ST_Polygon'
+
+    def __init__(self, expr):
+        super().__init__(expr)
+        expr = self.source_expressions[0]
+        if isinstance(expr, Value) and not expr._output_field_or_none:
+            self.source_expressions[0] = Value(expr.value, output_field=RasterField(srid=expr.value.srid))
+
+    @cached_property
+    def output_field(self):
+        return GeometryField(srid=self.source_expressions[0].field.srid)
 
 
 class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
@@ -134,10 +128,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
         'touches': PostGISOperator(func='ST_Touches', raster=BILATERAL),
         'within': PostGISOperator(func='ST_Within', raster=BILATERAL),
         'dwithin': PostGISOperator(func='ST_DWithin', geography=True, raster=BILATERAL),
-        'distance_gt': PostGISDistanceOperator(func='ST_Distance', op='>', geography=True),
-        'distance_gte': PostGISDistanceOperator(func='ST_Distance', op='>=', geography=True),
-        'distance_lt': PostGISDistanceOperator(func='ST_Distance', op='<', geography=True),
-        'distance_lte': PostGISDistanceOperator(func='ST_Distance', op='<=', geography=True),
     }
 
     unsupported_functions = set()
@@ -375,3 +365,19 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
     def parse_raster(self, value):
         """Convert a PostGIS HEX String into a dict readable by GDALRaster."""
         return from_pgraster(value)
+
+    def distance_expr_for_lookup(self, lhs, rhs, **kwargs):
+        return super().distance_expr_for_lookup(
+            self._normalize_distance_lookup_arg(lhs),
+            self._normalize_distance_lookup_arg(rhs),
+            **kwargs
+        )
+
+    @staticmethod
+    def _normalize_distance_lookup_arg(arg):
+        is_raster = (
+            arg.field.geom_type == 'RASTER'
+            if hasattr(arg, 'field') else
+            isinstance(arg, GDALRaster)
+        )
+        return ST_Polygon(arg) if is_raster else arg

+ 0 - 18
django/contrib/gis/db/backends/spatialite/operations.py

@@ -17,20 +17,6 @@ from django.utils.functional import cached_property
 from django.utils.version import get_version_tuple
 
 
-class SpatiaLiteDistanceOperator(SpatialOperator):
-    def as_sql(self, connection, lookup, template_params, sql_params):
-        if lookup.lhs.output_field.geodetic(connection):
-            # SpatiaLite returns NULL instead of zero on geodetic coordinates
-            sql_template = 'COALESCE(%(func)s(%(lhs)s, %(rhs)s, %%s), 0) %(op)s %(value)s'
-            template_params.update({
-                'op': self.op,
-                'func': connection.ops.spatial_function_name('Distance'),
-            })
-            sql_params.insert(1, len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid')
-            return sql_template % template_params, sql_params
-        return super().as_sql(connection, lookup, template_params, sql_params)
-
-
 class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
     name = 'spatialite'
     spatialite = True
@@ -68,10 +54,6 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
         'exact': SpatialOperator(func='Equals'),
         # Distance predicates
         'dwithin': SpatialOperator(func='PtDistWithin'),
-        'distance_gt': SpatiaLiteDistanceOperator(func='Distance', op='>'),
-        'distance_gte': SpatiaLiteDistanceOperator(func='Distance', op='>='),
-        'distance_lt': SpatiaLiteDistanceOperator(func='Distance', op='<'),
-        'distance_lte': SpatiaLiteDistanceOperator(func='Distance', op='<='),
     }
 
     disallowed_aggregates = (aggregates.Extent3D,)

+ 32 - 19
django/contrib/gis/db/models/lookups.py

@@ -305,22 +305,13 @@ class DistanceLookupBase(GISLookup):
         if len(self.rhs_params) > 1 and self.rhs_params[1] != 'spheroid':
             self.process_band_indices()
 
-    def process_rhs(self, compiler, connection):
-        params = [connection.ops.Adapter(self.rhs)]
-        # Getting the distance parameter in the units of the field.
+    def process_distance(self, compiler, connection):
         dist_param = self.rhs_params[0]
-        if hasattr(dist_param, 'resolve_expression'):
-            dist_param = dist_param.resolve_expression(compiler.query)
-            sql, expr_params = compiler.compile(dist_param)
-            self.template_params['value'] = sql
-            params.extend(expr_params)
-        else:
-            params += connection.ops.get_distance(
-                self.lhs.output_field, self.rhs_params,
-                self.lookup_name,
-            )
-        rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, params[0], compiler)
-        return (rhs, params)
+        return (
+            compiler.compile(dist_param.resolve_expression(compiler.query))
+            if hasattr(dist_param, 'resolve_expression') else
+            ('%s', connection.ops.get_distance(self.lhs.output_field, self.rhs_params, self.lookup_name))
+        )
 
 
 @BaseSpatialField.register_lookup
@@ -328,22 +319,44 @@ class DWithinLookup(DistanceLookupBase):
     lookup_name = 'dwithin'
     sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s)'
 
+    def process_rhs(self, compiler, connection):
+        dist_sql, dist_params = self.process_distance(compiler, connection)
+        self.template_params['value'] = dist_sql
+        rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, self.rhs, compiler)
+        return rhs, [connection.ops.Adapter(self.rhs)] + dist_params
+
+
+class DistanceLookupFromFunction(DistanceLookupBase):
+    def as_sql(self, compiler, connection):
+        spheroid = (len(self.rhs_params) == 2 and self.rhs_params[-1] == 'spheroid') or None
+        distance_expr = connection.ops.distance_expr_for_lookup(self.lhs, self.rhs, spheroid=spheroid)
+        sql, params = compiler.compile(distance_expr.resolve_expression(compiler.query))
+        dist_sql, dist_params = self.process_distance(compiler, connection)
+        return (
+            '%(func)s %(op)s %(dist)s' % {'func': sql, 'op': self.op, 'dist': dist_sql},
+            params + dist_params,
+        )
+
 
 @BaseSpatialField.register_lookup
-class DistanceGTLookup(DistanceLookupBase):
+class DistanceGTLookup(DistanceLookupFromFunction):
     lookup_name = 'distance_gt'
+    op = '>'
 
 
 @BaseSpatialField.register_lookup
-class DistanceGTELookup(DistanceLookupBase):
+class DistanceGTELookup(DistanceLookupFromFunction):
     lookup_name = 'distance_gte'
+    op = '>='
 
 
 @BaseSpatialField.register_lookup
-class DistanceLTLookup(DistanceLookupBase):
+class DistanceLTLookup(DistanceLookupFromFunction):
     lookup_name = 'distance_lt'
+    op = '<'
 
 
 @BaseSpatialField.register_lookup
-class DistanceLTELookup(DistanceLookupBase):
+class DistanceLTELookup(DistanceLookupFromFunction):
     lookup_name = 'distance_lte'
+    op = '<='

+ 8 - 1
tests/gis_tests/distapp/tests.py

@@ -1,5 +1,5 @@
 from django.contrib.gis.db.models.functions import (
-    Area, Distance, Length, Perimeter, Transform,
+    Area, Distance, Intersection, Length, Perimeter, Transform,
 )
 from django.contrib.gis.geos import GEOSGeometry, LineString, Point
 from django.contrib.gis.measure import D  # alias for Distance
@@ -206,6 +206,13 @@ class DistanceTest(TestCase):
             ).order_by('name')
             self.assertEqual(self.get_names(qs), ['Canberra', 'Hobart', 'Melbourne'])
 
+        # With a complex geometry expression
+        self.assertFalse(SouthTexasCity.objects.filter(point__distance_gt=(Intersection('point', 'point'), 0)))
+        self.assertEqual(
+            SouthTexasCity.objects.filter(point__distance_lte=(Intersection('point', 'point'), 0)).count(),
+            SouthTexasCity.objects.count(),
+        )
+
 
 '''
 =============================