tests.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. from __future__ import unicode_literals
  2. import os
  3. import re
  4. from django.contrib.gis.gdal import HAS_GDAL
  5. from django.core.management import call_command
  6. from django.db import connection, connections
  7. from django.test import TestCase, skipUnlessDBFeature
  8. from django.test.utils import modify_settings
  9. from django.utils.six import StringIO
  10. from ..test_data import TEST_DATA
  11. from ..utils import postgis
  12. if HAS_GDAL:
  13. from django.contrib.gis.gdal import Driver, GDALException, GDAL_VERSION
  14. from django.contrib.gis.utils.ogrinspect import ogrinspect
  15. from .models import AllOGRFields
  16. @skipUnlessDBFeature("gis_enabled")
  17. class InspectDbTests(TestCase):
  18. def test_geom_columns(self):
  19. """
  20. Test the geo-enabled inspectdb command.
  21. """
  22. out = StringIO()
  23. call_command(
  24. 'inspectdb',
  25. table_name_filter=lambda tn: tn == 'inspectapp_allogrfields',
  26. stdout=out
  27. )
  28. output = out.getvalue()
  29. if connection.features.supports_geometry_field_introspection:
  30. self.assertIn('geom = models.PolygonField()', output)
  31. self.assertIn('point = models.PointField()', output)
  32. else:
  33. self.assertIn('geom = models.GeometryField(', output)
  34. self.assertIn('point = models.GeometryField(', output)
  35. @skipUnlessDBFeature("supports_3d_storage")
  36. def test_3d_columns(self):
  37. out = StringIO()
  38. call_command(
  39. 'inspectdb',
  40. table_name_filter=lambda tn: tn == 'inspectapp_fields3d',
  41. stdout=out
  42. )
  43. output = out.getvalue()
  44. if connection.features.supports_geometry_field_introspection:
  45. self.assertIn('point = models.PointField(dim=3)', output)
  46. if postgis:
  47. # Geography type is specific to PostGIS
  48. self.assertIn('pointg = models.PointField(geography=True, dim=3)', output)
  49. self.assertIn('line = models.LineStringField(dim=3)', output)
  50. self.assertIn('poly = models.PolygonField(dim=3)', output)
  51. else:
  52. self.assertIn('point = models.GeometryField(', output)
  53. self.assertIn('pointg = models.GeometryField(', output)
  54. self.assertIn('line = models.GeometryField(', output)
  55. self.assertIn('poly = models.GeometryField(', output)
  56. @skipUnlessDBFeature("gis_enabled")
  57. @modify_settings(
  58. INSTALLED_APPS={'append': 'django.contrib.gis'},
  59. )
  60. class OGRInspectTest(TestCase):
  61. maxDiff = 1024
  62. def test_poly(self):
  63. shp_file = os.path.join(TEST_DATA, 'test_poly', 'test_poly.shp')
  64. model_def = ogrinspect(shp_file, 'MyModel')
  65. expected = [
  66. '# This is an auto-generated Django model module created by ogrinspect.',
  67. 'from django.contrib.gis.db import models',
  68. '',
  69. 'class MyModel(models.Model):',
  70. ' float = models.FloatField()',
  71. ' int = models.{}()'.format('BigIntegerField' if GDAL_VERSION >= (2, 0) else 'FloatField'),
  72. ' str = models.CharField(max_length=80)',
  73. ' geom = models.PolygonField(srid=-1)',
  74. ]
  75. self.assertEqual(model_def, '\n'.join(expected))
  76. def test_poly_multi(self):
  77. shp_file = os.path.join(TEST_DATA, 'test_poly', 'test_poly.shp')
  78. model_def = ogrinspect(shp_file, 'MyModel', multi_geom=True)
  79. self.assertIn('geom = models.MultiPolygonField(srid=-1)', model_def)
  80. # Same test with a 25D-type geometry field
  81. shp_file = os.path.join(TEST_DATA, 'gas_lines', 'gas_leitung.shp')
  82. model_def = ogrinspect(shp_file, 'MyModel', multi_geom=True)
  83. self.assertIn('geom = models.MultiLineStringField(srid=-1)', model_def)
  84. def test_date_field(self):
  85. shp_file = os.path.join(TEST_DATA, 'cities', 'cities.shp')
  86. model_def = ogrinspect(shp_file, 'City')
  87. expected = [
  88. '# This is an auto-generated Django model module created by ogrinspect.',
  89. 'from django.contrib.gis.db import models',
  90. '',
  91. 'class City(models.Model):',
  92. ' name = models.CharField(max_length=80)',
  93. ' population = models.{}()'.format('BigIntegerField' if GDAL_VERSION >= (2, 0) else 'FloatField'),
  94. ' density = models.FloatField()',
  95. ' created = models.DateField()',
  96. ' geom = models.PointField(srid=-1)',
  97. ]
  98. self.assertEqual(model_def, '\n'.join(expected))
  99. def test_time_field(self):
  100. # Getting the database identifier used by OGR, if None returned
  101. # GDAL does not have the support compiled in.
  102. ogr_db = get_ogr_db_string()
  103. if not ogr_db:
  104. self.skipTest("Unable to setup an OGR connection to your database")
  105. try:
  106. # Writing shapefiles via GDAL currently does not support writing OGRTime
  107. # fields, so we need to actually use a database
  108. model_def = ogrinspect(ogr_db, 'Measurement',
  109. layer_key=AllOGRFields._meta.db_table,
  110. decimal=['f_decimal'])
  111. except GDALException:
  112. self.skipTest("Unable to setup an OGR connection to your database")
  113. self.assertTrue(model_def.startswith(
  114. '# This is an auto-generated Django model module created by ogrinspect.\n'
  115. 'from django.contrib.gis.db import models\n'
  116. '\n'
  117. 'class Measurement(models.Model):\n'
  118. ))
  119. # The ordering of model fields might vary depending on several factors (version of GDAL, etc.)
  120. self.assertIn(' f_decimal = models.DecimalField(max_digits=0, decimal_places=0)', model_def)
  121. self.assertIn(' f_int = models.IntegerField()', model_def)
  122. self.assertIn(' f_datetime = models.DateTimeField()', model_def)
  123. self.assertIn(' f_time = models.TimeField()', model_def)
  124. self.assertIn(' f_float = models.FloatField()', model_def)
  125. self.assertIn(' f_char = models.CharField(max_length=10)', model_def)
  126. self.assertIn(' f_date = models.DateField()', model_def)
  127. # Some backends may have srid=-1
  128. self.assertIsNotNone(re.search(r' geom = models.PolygonField\(([^\)])*\)', model_def))
  129. def test_management_command(self):
  130. shp_file = os.path.join(TEST_DATA, 'cities', 'cities.shp')
  131. out = StringIO()
  132. call_command('ogrinspect', shp_file, 'City', stdout=out)
  133. output = out.getvalue()
  134. self.assertIn('class City(models.Model):', output)
  135. def get_ogr_db_string():
  136. """
  137. Construct the DB string that GDAL will use to inspect the database.
  138. GDAL will create its own connection to the database, so we re-use the
  139. connection settings from the Django test.
  140. """
  141. db = connections.databases['default']
  142. # Map from the django backend into the OGR driver name and database identifier
  143. # http://www.gdal.org/ogr/ogr_formats.html
  144. #
  145. # TODO: Support Oracle (OCI).
  146. drivers = {
  147. 'django.contrib.gis.db.backends.postgis': ('PostgreSQL', "PG:dbname='%(db_name)s'", ' '),
  148. 'django.contrib.gis.db.backends.mysql': ('MySQL', 'MYSQL:"%(db_name)s"', ','),
  149. 'django.contrib.gis.db.backends.spatialite': ('SQLite', '%(db_name)s', '')
  150. }
  151. db_engine = db['ENGINE']
  152. if db_engine not in drivers:
  153. return None
  154. drv_name, db_str, param_sep = drivers[db_engine]
  155. # Ensure that GDAL library has driver support for the database.
  156. try:
  157. Driver(drv_name)
  158. except GDALException:
  159. return None
  160. # SQLite/SpatiaLite in-memory databases
  161. if db['NAME'] == ":memory:":
  162. return None
  163. # Build the params of the OGR database connection string
  164. params = [db_str % {'db_name': db['NAME']}]
  165. def add(key, template):
  166. value = db.get(key, None)
  167. # Don't add the parameter if it is not in django's settings
  168. if value:
  169. params.append(template % value)
  170. add('HOST', "host='%s'")
  171. add('PORT', "port='%s'")
  172. add('USER', "user='%s'")
  173. add('PASSWORD', "password='%s'")
  174. return param_sep.join(params)