functions.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. from decimal import Decimal
  2. from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField
  3. from django.contrib.gis.db.models.sql import AreaField, DistanceField
  4. from django.contrib.gis.geos import GEOSGeometry
  5. from django.core.exceptions import FieldError
  6. from django.db.models import (
  7. BooleanField, FloatField, IntegerField, TextField, Transform,
  8. )
  9. from django.db.models.expressions import Func, Value
  10. from django.db.models.functions import Cast
  11. from django.db.utils import NotSupportedError
  12. from django.utils.functional import cached_property
  # Python types accepted wherever a GIS function expects a plain number.
  NUMERIC_TYPES = (int, float, Decimal)
  class GeoFuncMixin:
      """
      Common machinery for GIS database functions.

      ``geom_param_pos`` holds the 0-based positions of the source expressions
      that must be geometric. Literal ``Value`` arguments at those positions are
      checked when the function is built. All of them are checked again, and
      reprojected to a common SRID, when the function is resolved in a query.
      """
      function = None
      geom_param_pos = (0,)

      def __init__(self, *expressions, **extra):
          super().__init__(*expressions, **extra)
          # Only literal Value arguments can be inspected this early; other
          # expressions are validated in resolve_expression().
          for pos in self.geom_param_pos:
              literal = self.source_expressions[pos]
              if not isinstance(literal, Value):
                  continue
              try:
                  declared = literal.output_field
              except FieldError:
                  # output_field could not be determined for this Value.
                  declared = None
              geom = literal.value
              non_geometric = not isinstance(geom, GEOSGeometry) or (
                  declared and not isinstance(declared, GeometryField)
              )
              if non_geometric:
                  raise TypeError("%s function requires a geometric argument in position %d." % (self.name, pos + 1))
              if not declared:
                  # Without a declared field the SRID must come from the geometry
                  # itself; a GeometryField carrying that SRID is attached.
                  if not geom.srid:
                      raise ValueError("SRID is required for all geometries.")
                  self.source_expressions[pos] = Value(geom, output_field=GeometryField(srid=geom.srid))

      @property
      def name(self):
          # Also the key handed to spatial_function_name() in as_sql().
          return type(self).__name__

      @cached_property
      def geo_field(self):
          # Field of the first geometric argument.
          first_geom = self.source_expressions[self.geom_param_pos[0]]
          return first_geom.field

      def as_sql(self, compiler, connection, function=None, **extra_context):
          # Use the backend's name for this function unless a name is set on the
          # class or passed in explicitly.
          if not (self.function or function):
              function = connection.ops.spatial_function_name(self.name)
          return super().as_sql(compiler, connection, function=function, **extra_context)

      def resolve_expression(self, *args, **kwargs):
          resolved = super().resolve_expression(*args, **kwargs)

          # Every geometric position must now resolve to a GeometryField.
          fields = resolved.get_source_fields()
          for pos in self.geom_param_pos:
              if not isinstance(fields[pos], GeometryField):
                  raise TypeError(
                      "%s function requires a GeometryField in position %s, got %s." % (
                          self.name, pos + 1, type(fields[pos]).__name__,
                      )
                  )

          # Reproject the secondary geometries into the SRID of the first one so
          # they are comparable. ``Transform`` is looked up at call time, so it is
          # the GIS Transform class defined later in this module, not the lookup
          # Transform imported at the top.
          target_srid = resolved.geo_field.srid
          for pos in self.geom_param_pos[1:]:
              other = resolved.source_expressions[pos]
              if other.output_field.srid != target_srid:
                  resolved.source_expressions[pos] = Transform(other, target_srid).resolve_expression(*args, **kwargs)
          return resolved

      def _handle_param(self, value, param_name='', check_types=None):
          """
          Return ``value`` unchanged after an optional type check.

          Expressions (objects with ``resolve_expression``) are never checked.
          Other values must be instances of ``check_types`` when that is given;
          otherwise TypeError is raised.
          """
          if hasattr(value, 'resolve_expression') or not check_types:
              return value
          if not isinstance(value, check_types):
              raise TypeError(
                  "The %s parameter has the wrong type: should be %s." % (
                      param_name, check_types)
              )
          return value
  class GeoFunc(GeoFuncMixin, Func):
      """Base class for GIS functions: GeoFuncMixin behaviour on top of Func."""
  class GeomOutputGeoFunc(GeoFunc):
      """GeoFunc whose result is a geometry in the SRID of its first geometric argument."""

      @cached_property
      def output_field(self):
          source_srid = self.geo_field.srid
          return GeometryField(srid=source_srid)
  class SQLiteDecimalToFloatMixin:
      """
      By default, Decimal values are converted to str by the SQLite backend, which
      is not acceptable by the GIS functions expecting numeric values.

      Decimal literals are replaced by float literals on a copy of the function.
      The original Value objects are never modified, so compiling for SQLite
      leaves the caller's expression tree untouched.
      """
      def as_sqlite(self, compiler, connection, **extra_context):
          clone = self.copy()
          # copy() duplicates the list of source expressions, not the Value
          # objects in it, so the Decimal literals are swapped for new Values
          # instead of being assigned to in place.
          clone.set_source_expressions([
              Value(float(expr.value))
              if hasattr(expr, 'value') and isinstance(expr.value, Decimal)
              else expr
              for expr in clone.get_source_expressions()
          ])
          # Same MRO step as the previous ``super().as_sql``, applied to the clone.
          return super(SQLiteDecimalToFloatMixin, clone).as_sql(compiler, connection, **extra_context)
  class OracleToleranceMixin:
      """
      Append Oracle's tolerance argument to a GIS function call.

      The tolerance defaults to ``tolerance``. A ``tolerance`` entry in
      ``self.extra`` (keyword arguments passed to the function) overrides it.
      The value is checked to be numeric (or an expression) and is sent as an
      extra bound argument. It is never formatted into the SQL text, so a
      caller-supplied tolerance cannot inject SQL (CVE-2020-9402).
      """
      tolerance = 0.05

      def as_oracle(self, compiler, connection, **extra_context):
          tolerance = self._handle_param(
              self.extra.get('tolerance', self.tolerance),
              'tolerance',
              NUMERIC_TYPES,
          )
          if not hasattr(tolerance, 'resolve_expression'):
              tolerance = Value(tolerance)
          # Append the argument on a copy so it does not leak into ``self``.
          clone = self.copy()
          clone.set_source_expressions([*self.get_source_expressions(), tolerance])
          return clone.as_sql(compiler, connection, **extra_context)
  class Area(OracleToleranceMixin, GeoFunc):
      """Area of a single geometry, returned through an AreaField."""
      arity = 1

      @cached_property
      def output_field(self):
          return AreaField(self.geo_field)

      def as_sql(self, compiler, connection, **extra_context):
          unsupported = (
              not connection.features.supports_area_geodetic
              and self.geo_field.geodetic(connection)
          )
          if unsupported:
              raise NotSupportedError('Area on geodetic coordinate systems not supported.')
          return super().as_sql(compiler, connection, **extra_context)

      def as_sqlite(self, compiler, connection, **extra_context):
          # On geodetic fields, add a spheroid flag (rendered as 1) to the call.
          if self.geo_field.geodetic(connection):
              extra_context.update(
                  template='%(function)s(%(expressions)s, %(spheroid)d)',
                  spheroid=True,
              )
          return self.as_sql(compiler, connection, **extra_context)
  class Azimuth(GeoFunc):
      """Takes two geometries and returns a float."""
      arity = 2
      geom_param_pos = (0, 1)
      output_field = FloatField()
  class AsGeoJSON(GeoFunc):
      """
      GeoJSON text for a geometry.

      ``precision`` is appended unless it is None. ``bbox`` contributes 1 and
      ``crs`` contributes 2 to an options argument, which is left out when
      neither is set.
      """
      output_field = TextField()

      def __init__(self, expression, bbox=False, crs=False, precision=8, **extra):
          args = [expression]
          if precision is not None:
              args.append(self._handle_param(precision, 'precision', int))
          options = (1 if bbox else 0) + (2 if crs else 0)
          if options:
              args.append(options)
          super().__init__(*args, **extra)
  class AsGML(GeoFunc):
      """
      GML text for a geometry.

      The version is the first SQL argument, which is why the geometry sits at
      position 1.
      """
      geom_param_pos = (1,)
      output_field = TextField()

      def __init__(self, expression, version=2, precision=8, **extra):
          args = [version, expression]
          if precision is not None:
              args.append(self._handle_param(precision, 'precision', int))
          super().__init__(*args, **extra)

      def as_oracle(self, compiler, connection, **extra_context):
          # Oracle gets only the geometry; the version selects which function is used.
          version_expr, geom_expr = self.get_source_expressions()[:2]
          clone = self.copy()
          clone.set_source_expressions([geom_expr])
          if version_expr.value == 3:
              extra_context['function'] = 'SDO_UTIL.TO_GML311GEOMETRY'
          else:
              extra_context['function'] = 'SDO_UTIL.TO_GMLGEOMETRY'
          return super(AsGML, clone).as_sql(compiler, connection, **extra_context)
  class AsKML(AsGML):
      """KML text for a geometry; argument handling is inherited from AsGML."""

      def as_sqlite(self, compiler, connection, **extra_context):
          # No version parameter: drop the leading version argument.
          _version, *remaining = self.get_source_expressions()
          clone = self.copy()
          clone.set_source_expressions(remaining)
          return clone.as_sql(compiler, connection, **extra_context)
  class AsSVG(GeoFunc):
      """SVG text for a geometry; a non-expression ``relative`` is sent as an int."""
      output_field = TextField()

      def __init__(self, expression, relative=False, precision=8, **extra):
          if not hasattr(relative, 'resolve_expression'):
              relative = int(relative)
          checked_precision = self._handle_param(precision, 'precision', int)
          super().__init__(expression, relative, checked_precision, **extra)
  class BoundingCircle(OracleToleranceMixin, GeoFunc):
      """Takes a geometry plus a segment count (``num_seg``, default 48)."""

      def __init__(self, expression, num_seg=48, **extra):
          super().__init__(expression, num_seg, **extra)

      def as_oracle(self, compiler, connection, **extra_context):
          # Only the geometry is passed on Oracle; the segment count is dropped.
          geometry = self.get_source_expressions()[0]
          clone = self.copy()
          clone.set_source_expressions([geometry])
          return super(BoundingCircle, clone).as_oracle(compiler, connection, **extra_context)
  class Centroid(OracleToleranceMixin, GeomOutputGeoFunc):
      """One-argument geometry function returning a geometry in the input SRID."""
      arity = 1
  class Difference(OracleToleranceMixin, GeomOutputGeoFunc):
      """Two-geometry function; the second geometry is reprojected to the first's SRID."""
      geom_param_pos = (0, 1)
      arity = 2
  class DistanceResultMixin:
      """Gives distance-like functions a DistanceField result and a geography check."""

      @cached_property
      def output_field(self):
          return DistanceField(self.geo_field)

      def source_is_geography(self):
          """Truthy when the base field is a geography field with SRID 4326."""
          field = self.geo_field
          return field.geography and field.srid == 4326
  class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
      """
      Distance between the geometries at positions 0 and 1.

      ``spheroid`` (a bool or expression) picks the spheroid-based function on
      geodetic fields where the backend supports it.
      """
      geom_param_pos = (0, 1)
      spheroid = None

      def __init__(self, expr1, expr2, spheroid=None, **extra):
          if spheroid is not None:
              self.spheroid = self._handle_param(spheroid, 'spheroid', bool)
          super().__init__(expr1, expr2, **extra)

      def as_postgresql(self, compiler, connection, **extra_context):
          clone = self.copy()
          geography = self.source_is_geography()

          # The second argument must match the first one's geography flag.
          other = clone.source_expressions[1]
          if other.output_field.geography != geography:
              if isinstance(other, Value):
                  # Literal: flip the flag on its output field directly.
                  other.output_field.geography = geography
              else:
                  clone.source_expressions[1] = Cast(
                      other,
                      GeometryField(srid=other.output_field.srid, geography=geography),
                  )

          function = None
          if not geography and self.geo_field.geodetic(connection):
              # Geometry fields with geodetic (lon/lat) coordinates need special distance functions
              if self.spheroid:
                  # DistanceSpheroid is more accurate and resource intensive than DistanceSphere
                  function = connection.ops.spatial_function_name('DistanceSpheroid')
                  # Replace boolean param by the real spheroid of the base field
                  clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
              else:
                  function = connection.ops.spatial_function_name('DistanceSphere')
          return super(Distance, clone).as_sql(compiler, connection, function=function, **extra_context)

      def as_sqlite(self, compiler, connection, **extra_context):
          if self.geo_field.geodetic(connection):
              # SpatiaLite returns NULL instead of zero on geodetic coordinates
              extra_context.update(
                  template='COALESCE(%(function)s(%(expressions)s, %(spheroid)s), 0)',
                  spheroid=int(bool(self.spheroid)),
              )
          return super().as_sql(compiler, connection, **extra_context)
  class Envelope(GeomOutputGeoFunc):
      """One-argument geometry function returning a geometry in the input SRID."""
      arity = 1
  class ForcePolygonCW(GeomOutputGeoFunc):
      """One-argument geometry function returning a geometry in the input SRID."""
      arity = 1
  class GeoHash(GeoFunc):
      """GeoHash text for a geometry, with an optional integer ``precision``."""
      output_field = TextField()

      def __init__(self, expression, precision=None, **extra):
          args = [expression]
          if precision is not None:
              args.append(self._handle_param(precision, 'precision', int))
          super().__init__(*args, **extra)

      def as_mysql(self, compiler, connection, **extra_context):
          clone = self.copy()
          has_precision = len(clone.source_expressions) >= 2
          if not has_precision:
              # If no precision is provided, set it to the maximum.
              clone.source_expressions.append(Value(100))
          return clone.as_sql(compiler, connection, **extra_context)
  class Intersection(OracleToleranceMixin, GeomOutputGeoFunc):
      """Two-geometry function; the second geometry is reprojected to the first's SRID."""
      geom_param_pos = (0, 1)
      arity = 2
  @BaseSpatialField.register_lookup
  class IsValid(OracleToleranceMixin, GeoFuncMixin, Transform):
      """
      Boolean validity check, also registered as the ``isvalid`` lookup.

      ``Transform`` in the bases is the lookup Transform imported from
      django.db.models. The GIS Transform class defined later in this module
      rebinds that name only after this class has been created.
      """
      output_field = BooleanField()
      lookup_name = 'isvalid'

      def as_oracle(self, compiler, connection, **extra_context):
          sql, params = super().as_oracle(compiler, connection, **extra_context)
          # Map the 'TRUE' result to 1 and anything else to 0.
          wrapped = "CASE %s WHEN 'TRUE' THEN 1 ELSE 0 END" % sql
          return wrapped, params
  class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
      """
      Length of a geometry.

      ``spheroid`` (default True) selects the spheroid-based computation where
      the backend offers one.
      """

      def __init__(self, expr1, spheroid=True, **extra):
          self.spheroid = spheroid
          super().__init__(expr1, **extra)

      def as_sql(self, compiler, connection, **extra_context):
          is_geodetic = self.geo_field.geodetic(connection)
          if is_geodetic and not connection.features.supports_length_geodetic:
              raise NotSupportedError("This backend doesn't support Length on geodetic fields")
          return super().as_sql(compiler, connection, **extra_context)

      def as_postgresql(self, compiler, connection, **extra_context):
          clone = self.copy()
          function = None
          if self.source_is_geography():
              # Geography input: pass the spheroid flag as an extra argument.
              clone.source_expressions.append(Value(self.spheroid))
          elif self.geo_field.geodetic(connection):
              # Geometry fields with geodetic (lon/lat) coordinates need length_spheroid
              function = connection.ops.spatial_function_name('LengthSpheroid')
              clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
          elif min(f.dim for f in self.get_source_fields() if f) > 2:
              # Use the backend's 3D length function for higher-dimension inputs.
              function = connection.ops.length3d
          return super(Length, clone).as_sql(compiler, connection, function=function, **extra_context)

      def as_sqlite(self, compiler, connection, **extra_context):
          if self.geo_field.geodetic(connection):
              function = 'GeodesicLength' if self.spheroid else 'GreatCircleLength'
          else:
              function = None
          return super().as_sql(compiler, connection, function=function, **extra_context)
  class LineLocatePoint(GeoFunc):
      """Takes two geometries and returns a float."""
      geom_param_pos = (0, 1)
      arity = 2
      output_field = FloatField()
  class MakeValid(GeoFunc):
      """Plain GeoFunc; the SQL name comes from the backend's spatial_function_name()."""
  class MemSize(GeoFunc):
      """One-geometry function with an integer result."""
      arity = 1
      output_field = IntegerField()
  class NumGeometries(GeoFunc):
      """One-geometry function with an integer result."""
      arity = 1
      output_field = IntegerField()
  class NumPoints(GeoFunc):
      """One-geometry function with an integer result."""
      arity = 1
      output_field = IntegerField()
  class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
      """Perimeter of a single geometry, returned through a DistanceField."""
      arity = 1

      def as_postgresql(self, compiler, connection, **extra_context):
          if self.geo_field.geodetic(connection) and not self.source_is_geography():
              raise NotSupportedError("ST_Perimeter cannot use a non-projected non-geography field.")
          # Use the backend's 3D perimeter function for higher-dimension inputs.
          lowest_dim = min(f.dim for f in self.get_source_fields())
          function = connection.ops.perimeter3d if lowest_dim > 2 else None
          return super().as_sql(compiler, connection, function=function, **extra_context)

      def as_sqlite(self, compiler, connection, **extra_context):
          if not self.geo_field.geodetic(connection):
              return super().as_sql(compiler, connection, **extra_context)
          raise NotSupportedError("Perimeter cannot use a non-projected field.")
  class PointOnSurface(OracleToleranceMixin, GeomOutputGeoFunc):
      """One-argument geometry function returning a geometry in the input SRID."""
      arity = 1
  class Reverse(GeoFunc):
      """One-argument GeoFunc; the SQL name comes from spatial_function_name()."""
      arity = 1
  class Scale(SQLiteDecimalToFloatMixin, GeomOutputGeoFunc):
      """
      Takes a geometry and numeric x, y factors.

      ``z`` is sent only when it differs from 0.0.
      """

      def __init__(self, expression, x, y, z=0.0, **extra):
          args = [expression] + [
              self._handle_param(value, label, NUMERIC_TYPES)
              for label, value in (('x', x), ('y', y))
          ]
          if z != 0.0:
              args.append(self._handle_param(z, 'z', NUMERIC_TYPES))
          super().__init__(*args, **extra)
  class SnapToGrid(SQLiteDecimalToFloatMixin, GeomOutputGeoFunc):
      """
      Takes a geometry followed by 1, 2 or 4 numeric arguments.

      With 4 arguments, the last two are moved in front of the first two
      before being sent to the database.
      """

      def __init__(self, expression, *args, **extra):
          count = len(args)
          if count not in (1, 2, 4):
              raise ValueError('Must provide 1, 2, or 4 arguments to `SnapToGrid`.')
          if count == 4:
              # Reverse origin and size param ordering
              args = args[2:] + args[:2]
          checked = [self._handle_param(arg, '', NUMERIC_TYPES) for arg in args]
          super().__init__(expression, *checked, **extra)
  class SymDifference(OracleToleranceMixin, GeomOutputGeoFunc):
      """Two-geometry function; the second geometry is reprojected to the first's SRID."""
      geom_param_pos = (0, 1)
      arity = 2
  class Transform(GeomOutputGeoFunc):
      """
      Takes a geometry and an integer ``srid``.

      Unless an ``output_field`` is supplied, the result is a GeometryField with
      that SRID. This class rebinds the module-level name ``Transform``, which
      GeoFuncMixin.resolve_expression() uses for automatic reprojection.
      """

      def __init__(self, expression, srid, **extra):
          checked_srid = self._handle_param(srid, 'srid', int)
          if 'output_field' not in extra:
              extra['output_field'] = GeometryField(srid=srid)
          super().__init__(expression, checked_srid, **extra)
  class Translate(Scale):
      """Reuses Scale's x/y/z argument handling."""

      def as_sqlite(self, compiler, connection, **extra_context):
          clone = self.copy()
          missing_z = len(self.source_expressions) < 4
          if missing_z:
              # Always provide the z parameter for ST_Translate
              clone.source_expressions.append(Value(0))
          return super(Translate, clone).as_sqlite(compiler, connection, **extra_context)
  class Union(OracleToleranceMixin, GeomOutputGeoFunc):
      """Two-geometry function; the second geometry is reprojected to the first's SRID."""
      geom_param_pos = (0, 1)
      arity = 2