utils.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502
  1. from contextlib import contextmanager
  2. import logging
  3. import re
  4. import sys
  5. from threading import local
  6. import time
  7. from unittest import skipUnless
  8. import warnings
  9. from functools import wraps
  10. from xml.dom.minidom import parseString, Node
  11. from django.apps import app_cache
  12. from django.conf import settings, UserSettingsHolder
  13. from django.core import mail
  14. from django.core.signals import request_started
  15. from django.db import reset_queries
  16. from django.http import request
  17. from django.template import Template, loader, TemplateDoesNotExist
  18. from django.template.loaders import cached
  19. from django.test.signals import template_rendered, setting_changed
  20. from django.utils.encoding import force_str
  21. from django.utils import six
  22. from django.utils.translation import deactivate
  23. __all__ = (
  24. 'Approximate', 'ContextList', 'get_runner',
  25. 'modify_settings', 'override_settings',
  26. 'requires_tz_support',
  27. 'setup_test_environment', 'teardown_test_environment',
  28. )
  29. RESTORE_LOADERS_ATTR = '_original_template_source_loaders'
  30. TZ_SUPPORT = hasattr(time, 'tzset')
  31. class Approximate(object):
  32. def __init__(self, val, places=7):
  33. self.val = val
  34. self.places = places
  35. def __repr__(self):
  36. return repr(self.val)
  37. def __eq__(self, other):
  38. if self.val == other:
  39. return True
  40. return round(abs(self.val - other), self.places) == 0
  41. class ContextList(list):
  42. """A wrapper that provides direct key access to context items contained
  43. in a list of context objects.
  44. """
  45. def __getitem__(self, key):
  46. if isinstance(key, six.string_types):
  47. for subcontext in self:
  48. if key in subcontext:
  49. return subcontext[key]
  50. raise KeyError(key)
  51. else:
  52. return super(ContextList, self).__getitem__(key)
  53. def __contains__(self, key):
  54. try:
  55. self[key]
  56. except KeyError:
  57. return False
  58. return True
  59. def keys(self):
  60. """
  61. Flattened keys of subcontexts.
  62. """
  63. keys = set()
  64. for subcontext in self:
  65. for dict in subcontext:
  66. keys |= set(dict.keys())
  67. return keys
  68. def instrumented_test_render(self, context):
  69. """
  70. An instrumented Template render method, providing a signal
  71. that can be intercepted by the test system Client
  72. """
  73. template_rendered.send(sender=self, template=self, context=context)
  74. return self.nodelist.render(context)
  75. def setup_test_environment():
  76. """Perform any global pre-test setup. This involves:
  77. - Installing the instrumented test renderer
  78. - Set the email backend to the locmem email backend.
  79. - Setting the active locale to match the LANGUAGE_CODE setting.
  80. """
  81. Template._original_render = Template._render
  82. Template._render = instrumented_test_render
  83. # Storing previous values in the settings module itself is problematic.
  84. # Store them in arbitrary (but related) modules instead. See #20636.
  85. mail._original_email_backend = settings.EMAIL_BACKEND
  86. settings.EMAIL_BACKEND = 'django.core.mail.backends.locmem.EmailBackend'
  87. request._original_allowed_hosts = settings.ALLOWED_HOSTS
  88. settings.ALLOWED_HOSTS = ['*']
  89. mail.outbox = []
  90. deactivate()
  91. def teardown_test_environment():
  92. """Perform any global post-test teardown. This involves:
  93. - Restoring the original test renderer
  94. - Restoring the email sending functions
  95. """
  96. Template._render = Template._original_render
  97. del Template._original_render
  98. settings.EMAIL_BACKEND = mail._original_email_backend
  99. del mail._original_email_backend
  100. settings.ALLOWED_HOSTS = request._original_allowed_hosts
  101. del request._original_allowed_hosts
  102. del mail.outbox
  103. def get_runner(settings, test_runner_class=None):
  104. if not test_runner_class:
  105. test_runner_class = settings.TEST_RUNNER
  106. test_path = test_runner_class.split('.')
  107. # Allow for Python 2.5 relative paths
  108. if len(test_path) > 1:
  109. test_module_name = '.'.join(test_path[:-1])
  110. else:
  111. test_module_name = '.'
  112. test_module = __import__(test_module_name, {}, {}, force_str(test_path[-1]))
  113. test_runner = getattr(test_module, test_path[-1])
  114. return test_runner
  115. def setup_test_template_loader(templates_dict, use_cached_loader=False):
  116. """
  117. Changes Django to only find templates from within a dictionary (where each
  118. key is the template name and each value is the corresponding template
  119. content to return).
  120. Use meth:`restore_template_loaders` to restore the original loaders.
  121. """
  122. if hasattr(loader, RESTORE_LOADERS_ATTR):
  123. raise Exception("loader.%s already exists" % RESTORE_LOADERS_ATTR)
  124. def test_template_loader(template_name, template_dirs=None):
  125. "A custom template loader that loads templates from a dictionary."
  126. try:
  127. return (templates_dict[template_name], "test:%s" % template_name)
  128. except KeyError:
  129. raise TemplateDoesNotExist(template_name)
  130. if use_cached_loader:
  131. template_loader = cached.Loader(('test_template_loader',))
  132. template_loader._cached_loaders = (test_template_loader,)
  133. else:
  134. template_loader = test_template_loader
  135. setattr(loader, RESTORE_LOADERS_ATTR, loader.template_source_loaders)
  136. loader.template_source_loaders = (template_loader,)
  137. return template_loader
  138. def restore_template_loaders():
  139. """
  140. Restores the original template loaders after
  141. :meth:`setup_test_template_loader` has been run.
  142. """
  143. loader.template_source_loaders = getattr(loader, RESTORE_LOADERS_ATTR)
  144. delattr(loader, RESTORE_LOADERS_ATTR)
  145. class override_settings(object):
  146. """
  147. Acts as either a decorator, or a context manager. If it's a decorator it
  148. takes a function and returns a wrapped function. If it's a contextmanager
  149. it's used with the ``with`` statement. In either event entering/exiting
  150. are called before and after, respectively, the function/block is executed.
  151. """
  152. def __init__(self, **kwargs):
  153. self.options = kwargs
  154. def __enter__(self):
  155. self.enable()
  156. def __exit__(self, exc_type, exc_value, traceback):
  157. self.disable()
  158. def __call__(self, test_func):
  159. from django.test import SimpleTestCase
  160. if isinstance(test_func, type):
  161. if not issubclass(test_func, SimpleTestCase):
  162. raise Exception(
  163. "Only subclasses of Django SimpleTestCase can be decorated "
  164. "with override_settings")
  165. self.save_options(test_func)
  166. return test_func
  167. else:
  168. @wraps(test_func)
  169. def inner(*args, **kwargs):
  170. with self:
  171. return test_func(*args, **kwargs)
  172. return inner
  173. def save_options(self, test_func):
  174. if test_func._overridden_settings is None:
  175. test_func._overridden_settings = self.options
  176. else:
  177. # Duplicate dict to prevent subclasses from altering their parent.
  178. test_func._overridden_settings = dict(
  179. test_func._overridden_settings, **self.options)
  180. def enable(self):
  181. override = UserSettingsHolder(settings._wrapped)
  182. for key, new_value in self.options.items():
  183. setattr(override, key, new_value)
  184. self.wrapped = settings._wrapped
  185. settings._wrapped = override
  186. if 'INSTALLED_APPS' in self.options:
  187. app_cache.set_installed_apps(settings.INSTALLED_APPS)
  188. for key, new_value in self.options.items():
  189. setting_changed.send(sender=settings._wrapped.__class__,
  190. setting=key, value=new_value, enter=True)
  191. def disable(self):
  192. settings._wrapped = self.wrapped
  193. del self.wrapped
  194. if 'INSTALLED_APPS' in self.options:
  195. app_cache.unset_installed_apps()
  196. for key in self.options:
  197. new_value = getattr(settings, key, None)
  198. setting_changed.send(sender=settings._wrapped.__class__,
  199. setting=key, value=new_value, enter=False)
  200. class modify_settings(override_settings):
  201. """
  202. Like override_settings, but makes it possible to append, prepend or remove
  203. items instead of redefining the entire list.
  204. """
  205. def __init__(self, *args, **kwargs):
  206. if args:
  207. # Hack used when instaciating from SimpleTestCase._pre_setup.
  208. assert not kwargs
  209. self.operations = args[0]
  210. else:
  211. assert not args
  212. self.operations = list(kwargs.items())
  213. def save_options(self, test_func):
  214. if test_func._modified_settings is None:
  215. test_func._modified_settings = self.operations
  216. else:
  217. # Duplicate list to prevent subclasses from altering their parent.
  218. test_func._modified_settings = list(
  219. test_func._modified_settings) + self.operations
  220. def enable(self):
  221. self.options = {}
  222. for name, operations in self.operations:
  223. try:
  224. # When called from SimpleTestCase._pre_setup, values may be
  225. # overridden several times; cumulate changes.
  226. value = self.options[name]
  227. except KeyError:
  228. value = list(getattr(settings, name, []))
  229. for action, items in operations.items():
  230. # items my be a single value or an iterable.
  231. if isinstance(items, six.string_types):
  232. items = [items]
  233. if action == 'append':
  234. value = value + [item for item in items if item not in value]
  235. elif action == 'prepend':
  236. value = [item for item in items if item not in value] + value
  237. elif action == 'remove':
  238. value = [item for item in value if item not in items]
  239. else:
  240. raise ValueError("Unsupported action: %s" % action)
  241. self.options[name] = value
  242. super(modify_settings, self).enable()
  243. def compare_xml(want, got):
  244. """Tries to do a 'xml-comparison' of want and got. Plain string
  245. comparison doesn't always work because, for example, attribute
  246. ordering should not be important. Comment nodes are not considered in the
  247. comparison.
  248. Based on http://codespeak.net/svn/lxml/trunk/src/lxml/doctestcompare.py
  249. """
  250. _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')
  251. def norm_whitespace(v):
  252. return _norm_whitespace_re.sub(' ', v)
  253. def child_text(element):
  254. return ''.join([c.data for c in element.childNodes
  255. if c.nodeType == Node.TEXT_NODE])
  256. def children(element):
  257. return [c for c in element.childNodes
  258. if c.nodeType == Node.ELEMENT_NODE]
  259. def norm_child_text(element):
  260. return norm_whitespace(child_text(element))
  261. def attrs_dict(element):
  262. return dict(element.attributes.items())
  263. def check_element(want_element, got_element):
  264. if want_element.tagName != got_element.tagName:
  265. return False
  266. if norm_child_text(want_element) != norm_child_text(got_element):
  267. return False
  268. if attrs_dict(want_element) != attrs_dict(got_element):
  269. return False
  270. want_children = children(want_element)
  271. got_children = children(got_element)
  272. if len(want_children) != len(got_children):
  273. return False
  274. for want, got in zip(want_children, got_children):
  275. if not check_element(want, got):
  276. return False
  277. return True
  278. def first_node(document):
  279. for node in document.childNodes:
  280. if node.nodeType != Node.COMMENT_NODE:
  281. return node
  282. want, got = strip_quotes(want, got)
  283. want = want.replace('\\n', '\n')
  284. got = got.replace('\\n', '\n')
  285. # If the string is not a complete xml document, we may need to add a
  286. # root element. This allow us to compare fragments, like "<foo/><bar/>"
  287. if not want.startswith('<?xml'):
  288. wrapper = '<root>%s</root>'
  289. want = wrapper % want
  290. got = wrapper % got
  291. # Parse the want and got strings, and compare the parsings.
  292. want_root = first_node(parseString(want))
  293. got_root = first_node(parseString(got))
  294. return check_element(want_root, got_root)
  295. def strip_quotes(want, got):
  296. """
  297. Strip quotes of doctests output values:
  298. >>> strip_quotes("'foo'")
  299. "foo"
  300. >>> strip_quotes('"foo"')
  301. "foo"
  302. """
  303. def is_quoted_string(s):
  304. s = s.strip()
  305. return (len(s) >= 2
  306. and s[0] == s[-1]
  307. and s[0] in ('"', "'"))
  308. def is_quoted_unicode(s):
  309. s = s.strip()
  310. return (len(s) >= 3
  311. and s[0] == 'u'
  312. and s[1] == s[-1]
  313. and s[1] in ('"', "'"))
  314. if is_quoted_string(want) and is_quoted_string(got):
  315. want = want.strip()[1:-1]
  316. got = got.strip()[1:-1]
  317. elif is_quoted_unicode(want) and is_quoted_unicode(got):
  318. want = want.strip()[2:-1]
  319. got = got.strip()[2:-1]
  320. return want, got
  321. def str_prefix(s):
  322. return s % {'_': '' if six.PY3 else 'u'}
  323. class CaptureQueriesContext(object):
  324. """
  325. Context manager that captures queries executed by the specified connection.
  326. """
  327. def __init__(self, connection):
  328. self.connection = connection
  329. def __iter__(self):
  330. return iter(self.captured_queries)
  331. def __getitem__(self, index):
  332. return self.captured_queries[index]
  333. def __len__(self):
  334. return len(self.captured_queries)
  335. @property
  336. def captured_queries(self):
  337. return self.connection.queries[self.initial_queries:self.final_queries]
  338. def __enter__(self):
  339. self.use_debug_cursor = self.connection.use_debug_cursor
  340. self.connection.use_debug_cursor = True
  341. self.initial_queries = len(self.connection.queries)
  342. self.final_queries = None
  343. request_started.disconnect(reset_queries)
  344. return self
  345. def __exit__(self, exc_type, exc_value, traceback):
  346. self.connection.use_debug_cursor = self.use_debug_cursor
  347. request_started.connect(reset_queries)
  348. if exc_type is not None:
  349. return
  350. self.final_queries = len(self.connection.queries)
  351. class IgnoreDeprecationWarningsMixin(object):
  352. warning_classes = [DeprecationWarning]
  353. def setUp(self):
  354. super(IgnoreDeprecationWarningsMixin, self).setUp()
  355. self.catch_warnings = warnings.catch_warnings()
  356. self.catch_warnings.__enter__()
  357. for warning_class in self.warning_classes:
  358. warnings.filterwarnings("ignore", category=warning_class)
  359. def tearDown(self):
  360. self.catch_warnings.__exit__(*sys.exc_info())
  361. super(IgnoreDeprecationWarningsMixin, self).tearDown()
  362. class IgnorePendingDeprecationWarningsMixin(IgnoreDeprecationWarningsMixin):
  363. warning_classes = [PendingDeprecationWarning]
  364. class IgnoreAllDeprecationWarningsMixin(IgnoreDeprecationWarningsMixin):
  365. warning_classes = [PendingDeprecationWarning, DeprecationWarning]
  366. @contextmanager
  367. def patch_logger(logger_name, log_level):
  368. """
  369. Context manager that takes a named logger and the logging level
  370. and provides a simple mock-like list of messages received
  371. """
  372. calls = []
  373. def replacement(msg, *args, **kwargs):
  374. calls.append(msg % args)
  375. logger = logging.getLogger(logger_name)
  376. orig = getattr(logger, log_level)
  377. setattr(logger, log_level, replacement)
  378. try:
  379. yield calls
  380. finally:
  381. setattr(logger, log_level, orig)
  382. class TransRealMixin(object):
  383. """This is the only way to reset the translation machinery. Otherwise
  384. the test suite occasionally fails because of global state pollution
  385. between tests."""
  386. def flush_caches(self):
  387. from django.utils.translation import trans_real
  388. trans_real._translations = {}
  389. trans_real._active = local()
  390. trans_real._default = None
  391. trans_real.check_for_language.cache_clear()
  392. def tearDown(self):
  393. self.flush_caches()
  394. super(TransRealMixin, self).tearDown()
  395. # On OSes that don't provide tzset (Windows), we can't set the timezone
  396. # in which the program runs. As a consequence, we must skip tests that
  397. # don't enforce a specific timezone (with timezone.override or equivalent),
  398. # or attempt to interpret naive datetimes in the default timezone.
  399. requires_tz_support = skipUnless(TZ_SUPPORT,
  400. "This test relies on the ability to run a program in an arbitrary "
  401. "time zone, but your operating system isn't able to do that.")