123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458 |
- from decimal import Decimal
- from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField
- from django.contrib.gis.db.models.sql import AreaField, DistanceField
- from django.contrib.gis.geos import GEOSGeometry
- from django.core.exceptions import FieldError
- from django.db.models import (
- BooleanField, FloatField, IntegerField, TextField, Transform,
- )
- from django.db.models.expressions import Func, Value
- from django.db.models.functions import Cast
- from django.db.utils import NotSupportedError
- from django.utils.functional import cached_property
# Python types accepted wherever a spatial function takes a plain numeric
# argument (passed as ``check_types`` to ``_handle_param()``).
NUMERIC_TYPES = (int, float, Decimal)
class GeoFuncMixin:
    """
    Shared behaviour for spatial database functions.

    ``geom_param_pos`` lists the positions of the source expressions that must
    be geometric; the first listed position provides the reference field
    (see ``geo_field``). When ``function`` is left as None, the SQL function
    name is looked up on the backend from the class name.
    """
    function = None
    geom_param_pos = (0,)

    def __init__(self, *expressions, **extra):
        super().__init__(*expressions, **extra)
        # Only literal values can be validated here; other expressions are
        # checked once their fields are known, in resolve_expression().
        for index in self.geom_param_pos:
            candidate = self.source_expressions[index]
            if not isinstance(candidate, Value):
                continue
            try:
                declared_field = candidate.output_field
            except FieldError:
                declared_field = None
            geometry = candidate.value
            if not isinstance(geometry, GEOSGeometry) or (
                declared_field and not isinstance(declared_field, GeometryField)
            ):
                raise TypeError("%s function requires a geometric argument in position %d." % (self.name, index + 1))
            if declared_field:
                # The value already carries a geometry field; keep it as is.
                continue
            if not geometry.srid:
                raise ValueError("SRID is required for all geometries.")
            # Re-wrap the bare geometry with a field built from its own SRID.
            self.source_expressions[index] = Value(geometry, output_field=GeometryField(srid=geometry.srid))

    @property
    def name(self):
        """Return the class name, used for backend lookups and messages."""
        return type(self).__name__

    @cached_property
    def geo_field(self):
        """Return the field of the first geometric source expression."""
        first_geom_pos = self.geom_param_pos[0]
        return self.source_expressions[first_geom_pos].field

    def as_sql(self, compiler, connection, function=None, **extra_context):
        """Compile the function, resolving its SQL name from the backend if unset."""
        if not (self.function or function):
            function = connection.ops.spatial_function_name(self.name)
        return super().as_sql(compiler, connection, function=function, **extra_context)

    def resolve_expression(self, *args, **kwargs):
        """
        Resolve the expression and check that every geometric position is
        backed by a GeometryField.

        Raise TypeError when a geometric position resolves to another field
        type. Secondary geometries whose SRID differs from the reference
        field's are wrapped in a transform to the reference SRID.
        """
        resolved = super().resolve_expression(*args, **kwargs)
        fields = resolved.get_source_fields()
        for index in self.geom_param_pos:
            field = fields[index]
            if isinstance(field, GeometryField):
                continue
            raise TypeError(
                "%s function requires a GeometryField in position %s, got %s." % (
                    self.name, index + 1, type(field).__name__,
                )
            )

        reference_srid = resolved.geo_field.srid
        for index in self.geom_param_pos[1:]:
            operand = resolved.source_expressions[index]
            if operand.output_field.srid != reference_srid:
                # Automatic SRID conversion so objects are comparable.
                # ``Transform`` is looked up when this runs, i.e. it is the
                # (expression, srid) function class defined further down in
                # this module, not the lookup transform imported at the top.
                resolved.source_expressions[index] = Transform(
                    operand, reference_srid,
                ).resolve_expression(*args, **kwargs)
        return resolved

    def _handle_param(self, value, param_name='', check_types=None):
        """
        Return ``value`` unchanged after an optional type check.

        Objects exposing ``resolve_expression`` are never type-checked. Any
        other value must be an instance of ``check_types`` (when given),
        otherwise TypeError is raised.
        """
        is_expression = hasattr(value, 'resolve_expression')
        if not is_expression and check_types and not isinstance(value, check_types):
            raise TypeError(
                "The %s parameter has the wrong type: should be %s." % (
                    param_name, check_types)
            )
        return value
class GeoFunc(GeoFuncMixin, Func):
    """Concrete base class for spatial functions: GeoFuncMixin applied to Func."""
class GeomOutputGeoFunc(GeoFunc):
    """Spatial function whose result is a geometry in the source field's SRID."""

    @cached_property
    def output_field(self):
        """Return a GeometryField carrying the SRID of ``geo_field``."""
        source_srid = self.geo_field.srid
        return GeometryField(srid=source_srid)
class SQLiteDecimalToFloatMixin:
    """
    By default, Decimal values are converted to str by the SQLite backend, which
    is not acceptable by the GIS functions expecting numeric values.
    """

    def as_sqlite(self, compiler, connection, **extra_context):
        """
        Compile for SQLite with every Decimal-valued source expression
        replaced by a float-valued copy.

        The conversion is applied to a clone: neither this expression nor its
        source expressions are modified, so compiling a query has no lasting
        effect on the objects the caller built.
        """
        clone = self.copy()
        converted = []
        for expr in clone.get_source_expressions():
            if isinstance(getattr(expr, 'value', None), Decimal):
                # Copy before converting; the original value expression is
                # still referenced by ``self`` and must keep its Decimal.
                expr = expr.copy()
                expr.value = float(expr.value)
            converted.append(expr)
        clone.set_source_expressions(converted)
        # Continue along the MRO past this mixin, as a plain super() call
        # would, but acting on the clone.
        return super(SQLiteDecimalToFloatMixin, clone).as_sql(compiler, connection, **extra_context)
class OracleToleranceMixin:
    """
    Add a trailing tolerance argument to the function call on Oracle.

    The tolerance is read from the ``tolerance`` entry of ``self.extra`` and
    falls back to the class-level ``tolerance`` default.
    """
    tolerance = 0.05

    def as_oracle(self, compiler, connection, **extra_context):
        """
        Return ``(sql, params)`` with the tolerance bound as the last parameter.

        The tolerance is passed as a query parameter rather than formatted
        into the SQL template, so a caller-supplied value cannot change the
        structure of the statement.
        """
        tolerance = self.extra.get('tolerance', self.tolerance)
        sql, params = self.as_sql(
            compiler, connection,
            # "%%s" survives the template formatting as a "%s" placeholder.
            template='%(function)s(%(expressions)s, %%s)',
            **extra_context
        )
        # The placeholder is the last one in the SQL, so the tolerance goes
        # last in the parameters as well.
        return sql, [*params, tolerance]
class Area(OracleToleranceMixin, GeoFunc):
    """Single-argument spatial function whose result is an AreaField."""
    arity = 1

    @cached_property
    def output_field(self):
        """Return an AreaField built from the source geometry field."""
        return AreaField(self.geo_field)

    def as_sql(self, compiler, connection, **extra_context):
        """Compile the call; refuse geodetic fields on backends lacking support."""
        geodetic_supported = connection.features.supports_area_geodetic
        if not geodetic_supported and self.geo_field.geodetic(connection):
            raise NotSupportedError('Area on geodetic coordinate systems not supported.')
        return super().as_sql(compiler, connection, **extra_context)

    def as_sqlite(self, compiler, connection, **extra_context):
        """On SQLite, pass an extra spheroid flag for geodetic fields."""
        if self.geo_field.geodetic(connection):
            extra_context.update(
                template='%(function)s(%(expressions)s, %(spheroid)d)',
                spheroid=True,
            )
        return self.as_sql(compiler, connection, **extra_context)
class Azimuth(GeoFunc):
    """Spatial function of two geometric arguments with a float result."""
    arity = 2
    geom_param_pos = (0, 1)
    output_field = FloatField()
class AsGeoJSON(GeoFunc):
    """Spatial function with a text result, taking optional precision and options arguments."""
    output_field = TextField()

    def __init__(self, expression, bbox=False, crs=False, precision=8, **extra):
        """
        Build the argument list: the geometry, then ``precision`` (unless
        None), then an options value when ``bbox`` and/or ``crs`` is set.

        ``precision`` must be an int or an expression, else TypeError is raised.
        """
        arguments = [expression]
        if precision is not None:
            arguments.append(self._handle_param(precision, 'precision', int))
        # Options value: 1 when bbox is set, 2 when crs is set, 3 for both.
        options = (2 if crs else 0) + (1 if bbox else 0)
        if options:
            # NOTE(review): with precision=None the options value becomes the
            # second argument of the call -- confirm that is intended.
            arguments.append(options)
        super().__init__(*arguments, **extra)
class AsGML(GeoFunc):
    """
    Spatial function with a text result whose first argument is a version
    and whose second argument is the geometry.
    """
    output_field = TextField()
    geom_param_pos = (1,)

    def __init__(self, expression, version=2, precision=8, **extra):
        """
        Build the argument list as (version, geometry[, precision]).

        ``precision`` is left out when None; otherwise it must be an int or an
        expression, else TypeError is raised.
        """
        arguments = [version, expression]
        if precision is not None:
            arguments.append(self._handle_param(precision, 'precision', int))
        super().__init__(*arguments, **extra)

    def as_oracle(self, compiler, connection, **extra_context):
        """On Oracle, pass only the geometry and choose the function from the version."""
        sources = self.get_source_expressions()
        version_expr, geometry_expr = sources[0], sources[1]
        clone = self.copy()
        clone.set_source_expressions([geometry_expr])
        if version_expr.value == 3:
            extra_context['function'] = 'SDO_UTIL.TO_GML311GEOMETRY'
        else:
            extra_context['function'] = 'SDO_UTIL.TO_GMLGEOMETRY'
        return super(AsGML, clone).as_sql(compiler, connection, **extra_context)
class AsKML(AsGML):
    """Variant of AsGML that omits the version argument on SQLite."""

    def as_sqlite(self, compiler, connection, **extra_context):
        """Compile a clone without the leading version argument."""
        # No version parameter
        without_version = self.get_source_expressions()[1:]
        clone = self.copy()
        clone.set_source_expressions(without_version)
        return clone.as_sql(compiler, connection, **extra_context)
class AsSVG(GeoFunc):
    """Spatial function with a text result, taking relative and precision arguments."""
    output_field = TextField()

    def __init__(self, expression, relative=False, precision=8, **extra):
        """
        Build the argument list as (geometry, relative, precision).

        A ``relative`` value that is not an expression is coerced with
        ``int()``. ``precision`` must be an int or an expression, else
        TypeError is raised.
        """
        if not hasattr(relative, 'resolve_expression'):
            relative = int(relative)
        checked_precision = self._handle_param(precision, 'precision', int)
        super().__init__(expression, relative, checked_precision, **extra)
class BoundingCircle(OracleToleranceMixin, GeoFunc):
    """Spatial function taking a geometry and a ``num_seg`` argument (default 48)."""

    def __init__(self, expression, num_seg=48, **extra):
        super().__init__(expression, num_seg, **extra)

    def as_oracle(self, compiler, connection, **extra_context):
        """On Oracle, compile with the geometry argument only (``num_seg`` is dropped)."""
        geometry_expr = self.get_source_expressions()[0]
        clone = self.copy()
        clone.set_source_expressions([geometry_expr])
        return super(BoundingCircle, clone).as_oracle(compiler, connection, **extra_context)
class Centroid(OracleToleranceMixin, GeomOutputGeoFunc):
    """Single-argument spatial function with a geometry result."""
    arity = 1
class Difference(OracleToleranceMixin, GeomOutputGeoFunc):
    """Spatial function of two geometric arguments with a geometry result."""
    geom_param_pos = (0, 1)
    arity = 2
class DistanceResultMixin:
    """Mixin for spatial functions whose result is a DistanceField."""

    @cached_property
    def output_field(self):
        """Return a DistanceField built from the source geometry field."""
        return DistanceField(self.geo_field)

    def source_is_geography(self):
        """Return a truthy value when the source field is a geography with SRID 4326."""
        field = self.geo_field
        return field.geography and field.srid == 4326
class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
    """Spatial function of two geometric arguments with a DistanceField result."""
    geom_param_pos = (0, 1)
    spheroid = None

    def __init__(self, expr1, expr2, spheroid=None, **extra):
        """
        Store the optional ``spheroid`` flag and register both geometries.

        ``spheroid`` must be a bool or an expression, else TypeError is raised.
        """
        if spheroid is not None:
            self.spheroid = self._handle_param(spheroid, 'spheroid', bool)
        super().__init__(expr1, expr2, **extra)

    def as_postgresql(self, compiler, connection, **extra_context):
        """
        Compile for PostgreSQL.

        The second argument is aligned with the geography-ness of the first,
        and a sphere/spheroid function is selected for geodetic geometry
        fields.
        """
        clone = self.copy()
        second = clone.source_expressions[1]
        geography = self.source_is_geography()
        if second.output_field.geography != geography:
            if isinstance(second, Value):
                # Literal value: flip the flag on its field in place.
                second.output_field.geography = geography
            else:
                clone.source_expressions[1] = Cast(
                    second,
                    GeometryField(srid=second.output_field.srid, geography=geography),
                )

        function = None
        if not geography and self.geo_field.geodetic(connection):
            # Geometry fields with geodetic (lon/lat) coordinates need special distance functions
            if not self.spheroid:
                function = connection.ops.spatial_function_name('DistanceSphere')
            else:
                # DistanceSpheroid is more accurate and resource intensive than DistanceSphere
                function = connection.ops.spatial_function_name('DistanceSpheroid')
                # Replace boolean param by the real spheroid of the base field
                clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
        return super(Distance, clone).as_sql(compiler, connection, function=function, **extra_context)

    def as_sqlite(self, compiler, connection, **extra_context):
        """On SQLite, add the spheroid flag and COALESCE the result to 0 for geodetic fields."""
        if self.geo_field.geodetic(connection):
            # SpatiaLite returns NULL instead of zero on geodetic coordinates
            extra_context.update(
                template='COALESCE(%(function)s(%(expressions)s, %(spheroid)s), 0)',
                spheroid=int(bool(self.spheroid)),
            )
        return super().as_sql(compiler, connection, **extra_context)
class Envelope(GeomOutputGeoFunc):
    """Single-argument spatial function with a geometry result."""
    arity = 1
class ForcePolygonCW(GeomOutputGeoFunc):
    """Single-argument spatial function with a geometry result."""
    arity = 1
class GeoHash(GeoFunc):
    """Spatial function with a text result and an optional precision argument."""
    output_field = TextField()

    def __init__(self, expression, precision=None, **extra):
        """
        Register the geometry and, when given, the precision.

        ``precision`` must be an int or an expression, else TypeError is raised.
        """
        arguments = [expression]
        if precision is not None:
            arguments.append(self._handle_param(precision, 'precision', int))
        super().__init__(*arguments, **extra)

    def as_mysql(self, compiler, connection, **extra_context):
        """On MySQL, always pass a precision argument."""
        clone = self.copy()
        precision_missing = len(clone.source_expressions) < 2
        if precision_missing:
            # If no precision is provided, set it to the maximum.
            clone.source_expressions.append(Value(100))
        return clone.as_sql(compiler, connection, **extra_context)
class Intersection(OracleToleranceMixin, GeomOutputGeoFunc):
    """Spatial function of two geometric arguments with a geometry result."""
    geom_param_pos = (0, 1)
    arity = 2
@BaseSpatialField.register_lookup
class IsValid(OracleToleranceMixin, GeoFuncMixin, Transform):
    """
    Boolean spatial function, also registered on BaseSpatialField as the
    ``isvalid`` lookup.

    The ``Transform`` base is the class imported from django.db.models at the
    top of this module; the spatial function of the same name is only defined
    further down and is not what this class inherits from.
    """
    lookup_name = 'isvalid'
    output_field = BooleanField()

    def as_oracle(self, compiler, connection, **extra_context):
        """On Oracle, map the function's 'TRUE' result to 1 and anything else to 0."""
        inner_sql, params = super().as_oracle(compiler, connection, **extra_context)
        wrapped_sql = "CASE %s WHEN 'TRUE' THEN 1 ELSE 0 END" % inner_sql
        return wrapped_sql, params
class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
    """Single-geometry spatial function with a DistanceField result."""

    def __init__(self, expr1, spheroid=True, **extra):
        """Store the ``spheroid`` flag and register the geometry argument."""
        self.spheroid = spheroid
        super().__init__(expr1, **extra)

    def as_sql(self, compiler, connection, **extra_context):
        """Compile the call; refuse geodetic fields on backends lacking support."""
        if self.geo_field.geodetic(connection):
            if not connection.features.supports_length_geodetic:
                raise NotSupportedError("This backend doesn't support Length on geodetic fields")
        return super().as_sql(compiler, connection, **extra_context)

    def as_postgresql(self, compiler, connection, **extra_context):
        """
        Compile for PostgreSQL, choosing the arguments and function from the
        kind of source field (geography, geodetic geometry, or other).
        """
        clone = self.copy()
        function = None
        if self.source_is_geography():
            clone.source_expressions.append(Value(self.spheroid))
        elif self.geo_field.geodetic(connection):
            # Geometry fields with geodetic (lon/lat) coordinates need length_spheroid
            function = connection.ops.spatial_function_name('LengthSpheroid')
            clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
        else:
            # Use the 3D variant only when every known source field has more
            # than two dimensions.
            dims = [field.dim for field in self.get_source_fields() if field]
            if min(dims) > 2:
                function = connection.ops.length3d
        return super(Length, clone).as_sql(compiler, connection, function=function, **extra_context)

    def as_sqlite(self, compiler, connection, **extra_context):
        """On SQLite, pick the function from the ``spheroid`` flag for geodetic fields."""
        if not self.geo_field.geodetic(connection):
            function = None
        elif self.spheroid:
            function = 'GeodesicLength'
        else:
            function = 'GreatCircleLength'
        return super().as_sql(compiler, connection, function=function, **extra_context)
class LineLocatePoint(GeoFunc):
    """Spatial function of two geometric arguments with a float result."""
    arity = 2
    geom_param_pos = (0, 1)
    output_field = FloatField()
class MakeValid(GeoFunc):
    """Spatial function that adds nothing to the GeoFunc defaults."""
class MemSize(GeoFunc):
    """Single-argument spatial function with an integer result."""
    arity = 1
    output_field = IntegerField()
class NumGeometries(GeoFunc):
    """Single-argument spatial function with an integer result."""
    arity = 1
    output_field = IntegerField()
class NumPoints(GeoFunc):
    """Single-argument spatial function with an integer result."""
    arity = 1
    output_field = IntegerField()
class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
    """Single-argument spatial function with a DistanceField result."""
    arity = 1

    def as_postgresql(self, compiler, connection, **extra_context):
        """
        Compile for PostgreSQL.

        Geodetic non-geography fields are rejected with NotSupportedError. The
        3D function is used when every source field has more than two
        dimensions.
        """
        if self.geo_field.geodetic(connection) and not self.source_is_geography():
            raise NotSupportedError("ST_Perimeter cannot use a non-projected non-geography field.")
        dims = [field.dim for field in self.get_source_fields()]
        function = connection.ops.perimeter3d if min(dims) > 2 else None
        return super().as_sql(compiler, connection, function=function, **extra_context)

    def as_sqlite(self, compiler, connection, **extra_context):
        """Compile for SQLite; geodetic fields are rejected with NotSupportedError."""
        is_geodetic = self.geo_field.geodetic(connection)
        if is_geodetic:
            raise NotSupportedError("Perimeter cannot use a non-projected field.")
        return super().as_sql(compiler, connection, **extra_context)
class PointOnSurface(OracleToleranceMixin, GeomOutputGeoFunc):
    """Single-argument spatial function with a geometry result."""
    arity = 1
class Reverse(GeoFunc):
    """Single-argument spatial function."""
    arity = 1
class Scale(SQLiteDecimalToFloatMixin, GeomOutputGeoFunc):
    """Geometry-returning spatial function taking x, y and optional z numeric arguments."""

    def __init__(self, expression, x, y, z=0.0, **extra):
        """
        Build the argument list as (geometry, x, y[, z]).

        ``z`` is only passed when it differs from 0.0. Each numeric argument
        must be one of NUMERIC_TYPES or an expression, else TypeError is
        raised.
        """
        arguments = [expression]
        for label, number in (('x', x), ('y', y)):
            arguments.append(self._handle_param(number, label, NUMERIC_TYPES))
        if z != 0.0:
            arguments.append(self._handle_param(z, 'z', NUMERIC_TYPES))
        super().__init__(*arguments, **extra)
class SnapToGrid(SQLiteDecimalToFloatMixin, GeomOutputGeoFunc):
    """Geometry-returning spatial function taking 1, 2 or 4 numeric arguments."""

    def __init__(self, expression, *args, **extra):
        """
        Register the geometry followed by the numeric arguments.

        With four arguments, the last two are passed before the first two.
        Any other count than 1, 2 or 4 raises ValueError. Each argument must
        be one of NUMERIC_TYPES or an expression, else TypeError is raised.
        """
        nargs = len(args)
        if nargs in (1, 2):
            ordered = args
        elif nargs == 4:
            # Reverse origin and size param ordering
            ordered = args[2:] + args[:2]
        else:
            raise ValueError('Must provide 1, 2, or 4 arguments to `SnapToGrid`.')
        arguments = [expression]
        for number in ordered:
            arguments.append(self._handle_param(number, '', NUMERIC_TYPES))
        super().__init__(*arguments, **extra)
class SymDifference(OracleToleranceMixin, GeomOutputGeoFunc):
    """Spatial function of two geometric arguments with a geometry result."""
    geom_param_pos = (0, 1)
    arity = 2
class Transform(GeomOutputGeoFunc):
    """
    Geometry-returning spatial function taking a geometry and an SRID.

    Defining this class rebinds the module-level name ``Transform``, which
    until this point refers to the class imported from django.db.models.
    """

    def __init__(self, expression, srid, **extra):
        """
        Register the geometry and the target SRID.

        ``srid`` must be an int or an expression, else TypeError is raised.
        Unless the caller supplies ``output_field``, it defaults to a
        GeometryField in the given SRID.
        """
        checked_srid = self._handle_param(srid, 'srid', int)
        if 'output_field' not in extra:
            extra['output_field'] = GeometryField(srid=srid)
        super().__init__(expression, checked_srid, **extra)
class Translate(Scale):
    """Spatial function with the same arguments as Scale; on SQLite z is always passed."""

    def as_sqlite(self, compiler, connection, **extra_context):
        """Compile a clone that has a z argument appended when it is missing."""
        z_missing = len(self.source_expressions) < 4
        clone = self.copy()
        if z_missing:
            # Always provide the z parameter for ST_Translate
            clone.source_expressions.append(Value(0))
        return super(Translate, clone).as_sqlite(compiler, connection, **extra_context)
class Union(OracleToleranceMixin, GeomOutputGeoFunc):
    """Spatial function of two geometric arguments with a geometry result."""
    geom_param_pos = (0, 1)
    arity = 2
|