utils.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import copy
  2. import unittest
  3. from functools import wraps
  4. from unittest import mock
  5. from django.conf import settings
  6. from django.db import DEFAULT_DB_ALIAS, connection
  7. from django.db.models import Func
  8. def skipUnlessGISLookup(*gis_lookups):
  9. """
  10. Skip a test unless a database supports all of gis_lookups.
  11. """
  12. def decorator(test_func):
  13. @wraps(test_func)
  14. def skip_wrapper(*args, **kwargs):
  15. if any(key not in connection.ops.gis_operators for key in gis_lookups):
  16. raise unittest.SkipTest(
  17. "Database doesn't support all the lookups: %s" % ", ".join(gis_lookups)
  18. )
  19. return test_func(*args, **kwargs)
  20. return skip_wrapper
  21. return decorator
  22. def no_backend(test_func, backend):
  23. "Use this decorator to disable test on specified backend."
  24. if settings.DATABASES[DEFAULT_DB_ALIAS]['ENGINE'].rsplit('.')[-1] == backend:
  25. @unittest.skip("This test is skipped on '%s' backend" % backend)
  26. def inner():
  27. pass
  28. return inner
  29. else:
  30. return test_func
  31. # Decorators to disable entire test functions for specific
  32. # spatial backends.
  33. def no_oracle(func):
  34. return no_backend(func, 'oracle')
  35. # Shortcut booleans to omit only portions of tests.
  36. _default_db = settings.DATABASES[DEFAULT_DB_ALIAS]['ENGINE'].rsplit('.')[-1]
  37. oracle = _default_db == 'oracle'
  38. postgis = _default_db == 'postgis'
  39. mysql = _default_db == 'mysql'
  40. mariadb = mysql and connection.mysql_is_mariadb
  41. spatialite = _default_db == 'spatialite'
  42. # MySQL spatial indices can't handle NULL geometries.
  43. gisfield_may_be_null = not mysql
  44. if oracle and 'gis' in settings.DATABASES[DEFAULT_DB_ALIAS]['ENGINE']:
  45. from django.contrib.gis.db.backends.oracle.models import (
  46. OracleSpatialRefSys as SpatialRefSys,
  47. )
  48. elif postgis:
  49. from django.contrib.gis.db.backends.postgis.models import (
  50. PostGISSpatialRefSys as SpatialRefSys,
  51. )
  52. elif spatialite:
  53. from django.contrib.gis.db.backends.spatialite.models import (
  54. SpatialiteSpatialRefSys as SpatialRefSys,
  55. )
  56. else:
  57. SpatialRefSys = None
  58. class FuncTestMixin:
  59. """Assert that Func expressions aren't mutated during their as_sql()."""
  60. def setUp(self):
  61. def as_sql_wrapper(original_as_sql):
  62. def inner(*args, **kwargs):
  63. func = original_as_sql.__self__
  64. # Resolve output_field before as_sql() so touching it in
  65. # as_sql() won't change __dict__.
  66. func.output_field
  67. __dict__original = copy.deepcopy(func.__dict__)
  68. result = original_as_sql(*args, **kwargs)
  69. msg = '%s Func was mutated during compilation.' % func.__class__.__name__
  70. self.assertEqual(func.__dict__, __dict__original, msg)
  71. return result
  72. return inner
  73. def __getattribute__(self, name):
  74. if name != vendor_impl:
  75. return __getattribute__original(self, name)
  76. try:
  77. as_sql = __getattribute__original(self, vendor_impl)
  78. except AttributeError:
  79. as_sql = __getattribute__original(self, 'as_sql')
  80. return as_sql_wrapper(as_sql)
  81. vendor_impl = 'as_' + connection.vendor
  82. __getattribute__original = Func.__getattribute__
  83. self.func_patcher = mock.patch.object(Func, '__getattribute__', __getattribute__)
  84. self.func_patcher.start()
  85. super().setUp()
  86. def tearDown(self):
  87. super().tearDown()
  88. self.func_patcher.stop()