aggregates.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. from django.db.models.aggregates import Aggregate
  2. from django.contrib.gis.db.models.fields import GeometryField, ExtentField
  3. __all__ = ['Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union']
  4. class GeoAggregate(Aggregate):
  5. template = None
  6. function = None
  7. is_extent = False
  8. def as_sql(self, compiler, connection):
  9. if connection.ops.oracle:
  10. if not hasattr(self, 'tolerance'):
  11. self.tolerance = 0.05
  12. self.extra['tolerance'] = self.tolerance
  13. template, function = connection.ops.spatial_aggregate_sql(self)
  14. if template is None:
  15. template = '%(function)s(%(expressions)s)'
  16. self.extra['template'] = self.extra.get('template', template)
  17. self.extra['function'] = self.extra.get('function', function)
  18. return super(GeoAggregate, self).as_sql(compiler, connection)
  19. def prepare(self, query=None, allow_joins=True, reuse=None, summarize=False):
  20. c = super(GeoAggregate, self).prepare(query, allow_joins, reuse, summarize)
  21. if not isinstance(self.expressions[0].output_field, GeometryField):
  22. raise ValueError('Geospatial aggregates only allowed on geometry fields.')
  23. return c
  24. def convert_value(self, value, connection, context):
  25. return connection.ops.convert_geom(value, self.output_field)
  26. class Collect(GeoAggregate):
  27. name = 'Collect'
  28. class Extent(GeoAggregate):
  29. name = 'Extent'
  30. is_extent = '2D'
  31. def __init__(self, expression, **extra):
  32. super(Extent, self).__init__(expression, output_field=ExtentField(), **extra)
  33. def convert_value(self, value, connection, context):
  34. return connection.ops.convert_extent(value, context.get('transformed_srid'))
  35. class Extent3D(GeoAggregate):
  36. name = 'Extent3D'
  37. is_extent = '3D'
  38. def __init__(self, expression, **extra):
  39. super(Extent3D, self).__init__(expression, output_field=ExtentField(), **extra)
  40. def convert_value(self, value, connection, context):
  41. return connection.ops.convert_extent3d(value, context.get('transformed_srid'))
  42. class MakeLine(GeoAggregate):
  43. name = 'MakeLine'
  44. class Union(GeoAggregate):
  45. name = 'Union'