tests.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. import os
  2. import re
  3. from io import StringIO
  4. from django.contrib.gis.gdal import GDAL_VERSION, Driver, GDALException
  5. from django.contrib.gis.utils.ogrinspect import ogrinspect
  6. from django.core.management import call_command
  7. from django.db import connection, connections
  8. from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
  9. from django.test.utils import modify_settings
  10. from ..test_data import TEST_DATA
  11. from .models import AllOGRFields
  12. class InspectDbTests(TestCase):
  13. def test_geom_columns(self):
  14. """
  15. Test the geo-enabled inspectdb command.
  16. """
  17. out = StringIO()
  18. call_command(
  19. "inspectdb",
  20. table_name_filter=lambda tn: tn == "inspectapp_allogrfields",
  21. stdout=out,
  22. )
  23. output = out.getvalue()
  24. if connection.features.supports_geometry_field_introspection:
  25. self.assertIn("geom = models.PolygonField()", output)
  26. self.assertIn("point = models.PointField()", output)
  27. else:
  28. self.assertIn("geom = models.GeometryField(", output)
  29. self.assertIn("point = models.GeometryField(", output)
  30. @skipUnlessDBFeature("supports_3d_storage")
  31. def test_3d_columns(self):
  32. out = StringIO()
  33. call_command(
  34. "inspectdb",
  35. table_name_filter=lambda tn: tn == "inspectapp_fields3d",
  36. stdout=out,
  37. )
  38. output = out.getvalue()
  39. if connection.features.supports_geometry_field_introspection:
  40. self.assertIn("point = models.PointField(dim=3)", output)
  41. if connection.features.supports_geography:
  42. self.assertIn(
  43. "pointg = models.PointField(geography=True, dim=3)", output
  44. )
  45. else:
  46. self.assertIn("pointg = models.PointField(dim=3)", output)
  47. self.assertIn("line = models.LineStringField(dim=3)", output)
  48. self.assertIn("poly = models.PolygonField(dim=3)", output)
  49. else:
  50. self.assertIn("point = models.GeometryField(", output)
  51. self.assertIn("pointg = models.GeometryField(", output)
  52. self.assertIn("line = models.GeometryField(", output)
  53. self.assertIn("poly = models.GeometryField(", output)
  54. @modify_settings(
  55. INSTALLED_APPS={"append": "django.contrib.gis"},
  56. )
  57. class OGRInspectTest(SimpleTestCase):
  58. maxDiff = 1024
  59. def test_poly(self):
  60. shp_file = os.path.join(TEST_DATA, "test_poly", "test_poly.shp")
  61. model_def = ogrinspect(shp_file, "MyModel")
  62. expected = [
  63. "# This is an auto-generated Django model module created by ogrinspect.",
  64. "from django.contrib.gis.db import models",
  65. "",
  66. "",
  67. "class MyModel(models.Model):",
  68. " float = models.FloatField()",
  69. " int = models.BigIntegerField()",
  70. " str = models.CharField(max_length=80)",
  71. " geom = models.PolygonField()",
  72. ]
  73. self.assertEqual(model_def, "\n".join(expected))
  74. def test_poly_multi(self):
  75. shp_file = os.path.join(TEST_DATA, "test_poly", "test_poly.shp")
  76. model_def = ogrinspect(shp_file, "MyModel", multi_geom=True)
  77. self.assertIn("geom = models.MultiPolygonField()", model_def)
  78. # Same test with a 25D-type geometry field
  79. shp_file = os.path.join(TEST_DATA, "gas_lines", "gas_leitung.shp")
  80. model_def = ogrinspect(shp_file, "MyModel", multi_geom=True)
  81. srid = "-1" if GDAL_VERSION < (2, 3) else "31253"
  82. self.assertIn("geom = models.MultiLineStringField(srid=%s)" % srid, model_def)
  83. def test_date_field(self):
  84. shp_file = os.path.join(TEST_DATA, "cities", "cities.shp")
  85. model_def = ogrinspect(shp_file, "City")
  86. expected = [
  87. "# This is an auto-generated Django model module created by ogrinspect.",
  88. "from django.contrib.gis.db import models",
  89. "",
  90. "",
  91. "class City(models.Model):",
  92. " name = models.CharField(max_length=80)",
  93. " population = models.BigIntegerField()",
  94. " density = models.FloatField()",
  95. " created = models.DateField()",
  96. " geom = models.PointField()",
  97. ]
  98. self.assertEqual(model_def, "\n".join(expected))
  99. def test_time_field(self):
  100. # Getting the database identifier used by OGR, if None returned
  101. # GDAL does not have the support compiled in.
  102. ogr_db = get_ogr_db_string()
  103. if not ogr_db:
  104. self.skipTest("Unable to setup an OGR connection to your database")
  105. try:
  106. # Writing shapefiles via GDAL currently does not support writing OGRTime
  107. # fields, so we need to actually use a database
  108. model_def = ogrinspect(
  109. ogr_db,
  110. "Measurement",
  111. layer_key=AllOGRFields._meta.db_table,
  112. decimal=["f_decimal"],
  113. )
  114. except GDALException:
  115. self.skipTest("Unable to setup an OGR connection to your database")
  116. self.assertTrue(
  117. model_def.startswith(
  118. "# This is an auto-generated Django model module created by "
  119. "ogrinspect.\n"
  120. "from django.contrib.gis.db import models\n"
  121. "\n"
  122. "\n"
  123. "class Measurement(models.Model):\n"
  124. )
  125. )
  126. # The ordering of model fields might vary depending on several factors
  127. # (version of GDAL, etc.).
  128. if connection.vendor == "sqlite" and GDAL_VERSION < (3, 4):
  129. # SpatiaLite introspection is somewhat lacking on GDAL < 3.4 (#29461).
  130. self.assertIn(" f_decimal = models.CharField(max_length=0)", model_def)
  131. else:
  132. self.assertIn(
  133. " f_decimal = models.DecimalField(max_digits=0, decimal_places=0)",
  134. model_def,
  135. )
  136. self.assertIn(" f_int = models.IntegerField()", model_def)
  137. if not connection.ops.mariadb:
  138. # Probably a bug between GDAL and MariaDB on time fields.
  139. self.assertIn(" f_datetime = models.DateTimeField()", model_def)
  140. self.assertIn(" f_time = models.TimeField()", model_def)
  141. if connection.vendor == "sqlite" and GDAL_VERSION < (3, 4):
  142. self.assertIn(" f_float = models.CharField(max_length=0)", model_def)
  143. else:
  144. self.assertIn(" f_float = models.FloatField()", model_def)
  145. max_length = 0 if connection.vendor == "sqlite" else 10
  146. self.assertIn(
  147. " f_char = models.CharField(max_length=%s)" % max_length, model_def
  148. )
  149. self.assertIn(" f_date = models.DateField()", model_def)
  150. # Some backends may have srid=-1
  151. self.assertIsNotNone(
  152. re.search(r" geom = models.PolygonField\(([^\)])*\)", model_def)
  153. )
  154. def test_management_command(self):
  155. shp_file = os.path.join(TEST_DATA, "cities", "cities.shp")
  156. out = StringIO()
  157. call_command("ogrinspect", shp_file, "City", stdout=out)
  158. output = out.getvalue()
  159. self.assertIn("class City(models.Model):", output)
  160. def test_mapping_option(self):
  161. expected = (
  162. " geom = models.PointField()\n"
  163. "\n"
  164. "\n"
  165. "# Auto-generated `LayerMapping` dictionary for City model\n"
  166. "city_mapping = {\n"
  167. " 'name': 'Name',\n"
  168. " 'population': 'Population',\n"
  169. " 'density': 'Density',\n"
  170. " 'created': 'Created',\n"
  171. " 'geom': 'POINT',\n"
  172. "}\n"
  173. )
  174. shp_file = os.path.join(TEST_DATA, "cities", "cities.shp")
  175. out = StringIO()
  176. call_command("ogrinspect", shp_file, "--mapping", "City", stdout=out)
  177. self.assertIn(expected, out.getvalue())
  178. def get_ogr_db_string():
  179. """
  180. Construct the DB string that GDAL will use to inspect the database.
  181. GDAL will create its own connection to the database, so we re-use the
  182. connection settings from the Django test.
  183. """
  184. db = connections.databases["default"]
  185. # Map from the django backend into the OGR driver name and database identifier
  186. # https://gdal.org/drivers/vector/
  187. #
  188. # TODO: Support Oracle (OCI).
  189. drivers = {
  190. "django.contrib.gis.db.backends.postgis": (
  191. "PostgreSQL",
  192. "PG:dbname='%(db_name)s'",
  193. " ",
  194. ),
  195. "django.contrib.gis.db.backends.mysql": ("MySQL", 'MYSQL:"%(db_name)s"', ","),
  196. "django.contrib.gis.db.backends.spatialite": ("SQLite", "%(db_name)s", ""),
  197. }
  198. db_engine = db["ENGINE"]
  199. if db_engine not in drivers:
  200. return None
  201. drv_name, db_str, param_sep = drivers[db_engine]
  202. # Ensure that GDAL library has driver support for the database.
  203. try:
  204. Driver(drv_name)
  205. except GDALException:
  206. return None
  207. # SQLite/SpatiaLite in-memory databases
  208. if db["NAME"] == ":memory:":
  209. return None
  210. # Build the params of the OGR database connection string
  211. params = [db_str % {"db_name": db["NAME"]}]
  212. def add(key, template):
  213. value = db.get(key, None)
  214. # Don't add the parameter if it is not in django's settings
  215. if value:
  216. params.append(template % value)
  217. add("HOST", "host='%s'")
  218. add("PORT", "port='%s'")
  219. add("USER", "user='%s'")
  220. add("PASSWORD", "password='%s'")
  221. return param_sep.join(params)