tests.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. from __future__ import unicode_literals
  2. import os
  3. import re
  4. from django.contrib.gis.gdal import HAS_GDAL
  5. from django.core.management import call_command
  6. from django.db import connection, connections
  7. from django.test import TestCase, skipUnlessDBFeature
  8. from django.test.utils import modify_settings
  9. from django.utils.six import StringIO
  10. from ..test_data import TEST_DATA
  11. if HAS_GDAL:
  12. from django.contrib.gis.gdal import Driver, GDALException, GDAL_VERSION
  13. from django.contrib.gis.utils.ogrinspect import ogrinspect
  14. from .models import AllOGRFields
  15. @skipUnlessDBFeature("gis_enabled")
  16. class InspectDbTests(TestCase):
  17. def test_geom_columns(self):
  18. """
  19. Test the geo-enabled inspectdb command.
  20. """
  21. out = StringIO()
  22. call_command(
  23. 'inspectdb',
  24. table_name_filter=lambda tn: tn == 'inspectapp_allogrfields',
  25. stdout=out
  26. )
  27. output = out.getvalue()
  28. if connection.features.supports_geometry_field_introspection:
  29. self.assertIn('geom = models.PolygonField()', output)
  30. self.assertIn('point = models.PointField()', output)
  31. else:
  32. self.assertIn('geom = models.GeometryField(', output)
  33. self.assertIn('point = models.GeometryField(', output)
  34. @skipUnlessDBFeature("supports_3d_storage")
  35. def test_3d_columns(self):
  36. out = StringIO()
  37. call_command(
  38. 'inspectdb',
  39. table_name_filter=lambda tn: tn == 'inspectapp_fields3d',
  40. stdout=out
  41. )
  42. output = out.getvalue()
  43. if connection.features.supports_geometry_field_introspection:
  44. self.assertIn('point = models.PointField(dim=3)', output)
  45. self.assertIn('line = models.LineStringField(dim=3)', output)
  46. self.assertIn('poly = models.PolygonField(dim=3)', output)
  47. else:
  48. self.assertIn('point = models.GeometryField(', output)
  49. self.assertIn('line = models.GeometryField(', output)
  50. self.assertIn('poly = models.GeometryField(', output)
  51. @skipUnlessDBFeature("gis_enabled")
  52. @modify_settings(
  53. INSTALLED_APPS={'append': 'django.contrib.gis'},
  54. )
  55. class OGRInspectTest(TestCase):
  56. maxDiff = 1024
  57. def test_poly(self):
  58. shp_file = os.path.join(TEST_DATA, 'test_poly', 'test_poly.shp')
  59. model_def = ogrinspect(shp_file, 'MyModel')
  60. expected = [
  61. '# This is an auto-generated Django model module created by ogrinspect.',
  62. 'from django.contrib.gis.db import models',
  63. '',
  64. 'class MyModel(models.Model):',
  65. ' float = models.FloatField()',
  66. ' int = models.{}()'.format('BigIntegerField' if GDAL_VERSION >= (2, 0) else 'FloatField'),
  67. ' str = models.CharField(max_length=80)',
  68. ' geom = models.PolygonField(srid=-1)',
  69. ]
  70. self.assertEqual(model_def, '\n'.join(expected))
  71. def test_poly_multi(self):
  72. shp_file = os.path.join(TEST_DATA, 'test_poly', 'test_poly.shp')
  73. model_def = ogrinspect(shp_file, 'MyModel', multi_geom=True)
  74. self.assertIn('geom = models.MultiPolygonField(srid=-1)', model_def)
  75. # Same test with a 25D-type geometry field
  76. shp_file = os.path.join(TEST_DATA, 'gas_lines', 'gas_leitung.shp')
  77. model_def = ogrinspect(shp_file, 'MyModel', multi_geom=True)
  78. self.assertIn('geom = models.MultiLineStringField(srid=-1)', model_def)
  79. def test_date_field(self):
  80. shp_file = os.path.join(TEST_DATA, 'cities', 'cities.shp')
  81. model_def = ogrinspect(shp_file, 'City')
  82. expected = [
  83. '# This is an auto-generated Django model module created by ogrinspect.',
  84. 'from django.contrib.gis.db import models',
  85. '',
  86. 'class City(models.Model):',
  87. ' name = models.CharField(max_length=80)',
  88. ' population = models.{}()'.format('BigIntegerField' if GDAL_VERSION >= (2, 0) else 'FloatField'),
  89. ' density = models.FloatField()',
  90. ' created = models.DateField()',
  91. ' geom = models.PointField(srid=-1)',
  92. ]
  93. self.assertEqual(model_def, '\n'.join(expected))
  94. def test_time_field(self):
  95. # Getting the database identifier used by OGR, if None returned
  96. # GDAL does not have the support compiled in.
  97. ogr_db = get_ogr_db_string()
  98. if not ogr_db:
  99. self.skipTest("Unable to setup an OGR connection to your database")
  100. try:
  101. # Writing shapefiles via GDAL currently does not support writing OGRTime
  102. # fields, so we need to actually use a database
  103. model_def = ogrinspect(ogr_db, 'Measurement',
  104. layer_key=AllOGRFields._meta.db_table,
  105. decimal=['f_decimal'])
  106. except GDALException:
  107. self.skipTest("Unable to setup an OGR connection to your database")
  108. self.assertTrue(model_def.startswith(
  109. '# This is an auto-generated Django model module created by ogrinspect.\n'
  110. 'from django.contrib.gis.db import models\n'
  111. '\n'
  112. 'class Measurement(models.Model):\n'
  113. ))
  114. # The ordering of model fields might vary depending on several factors (version of GDAL, etc.)
  115. self.assertIn(' f_decimal = models.DecimalField(max_digits=0, decimal_places=0)', model_def)
  116. self.assertIn(' f_int = models.IntegerField()', model_def)
  117. self.assertIn(' f_datetime = models.DateTimeField()', model_def)
  118. self.assertIn(' f_time = models.TimeField()', model_def)
  119. self.assertIn(' f_float = models.FloatField()', model_def)
  120. self.assertIn(' f_char = models.CharField(max_length=10)', model_def)
  121. self.assertIn(' f_date = models.DateField()', model_def)
  122. # Some backends may have srid=-1
  123. self.assertIsNotNone(re.search(r' geom = models.PolygonField\(([^\)])*\)', model_def))
  124. def test_management_command(self):
  125. shp_file = os.path.join(TEST_DATA, 'cities', 'cities.shp')
  126. out = StringIO()
  127. call_command('ogrinspect', shp_file, 'City', stdout=out)
  128. output = out.getvalue()
  129. self.assertIn('class City(models.Model):', output)
  130. def get_ogr_db_string():
  131. """
  132. Construct the DB string that GDAL will use to inspect the database.
  133. GDAL will create its own connection to the database, so we re-use the
  134. connection settings from the Django test.
  135. """
  136. db = connections.databases['default']
  137. # Map from the django backend into the OGR driver name and database identifier
  138. # http://www.gdal.org/ogr/ogr_formats.html
  139. #
  140. # TODO: Support Oracle (OCI).
  141. drivers = {
  142. 'django.contrib.gis.db.backends.postgis': ('PostgreSQL', "PG:dbname='%(db_name)s'", ' '),
  143. 'django.contrib.gis.db.backends.mysql': ('MySQL', 'MYSQL:"%(db_name)s"', ','),
  144. 'django.contrib.gis.db.backends.spatialite': ('SQLite', '%(db_name)s', '')
  145. }
  146. db_engine = db['ENGINE']
  147. if db_engine not in drivers:
  148. return None
  149. drv_name, db_str, param_sep = drivers[db_engine]
  150. # Ensure that GDAL library has driver support for the database.
  151. try:
  152. Driver(drv_name)
  153. except GDALException:
  154. return None
  155. # SQLite/Spatialite in-memory databases
  156. if db['NAME'] == ":memory:":
  157. return None
  158. # Build the params of the OGR database connection string
  159. params = [db_str % {'db_name': db['NAME']}]
  160. def add(key, template):
  161. value = db.get(key, None)
  162. # Don't add the parameter if it is not in django's settings
  163. if value:
  164. params.append(template % value)
  165. add('HOST', "host='%s'")
  166. add('PORT', "port='%s'")
  167. add('USER', "user='%s'")
  168. add('PASSWORD', "password='%s'")
  169. return param_sep.join(params)