utils.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. from contextlib import contextmanager
  2. import logging
  3. import re
  4. import sys
  5. import time
  6. from unittest import skipUnless
  7. import warnings
  8. from functools import wraps
  9. from xml.dom.minidom import parseString, Node
  10. from django.apps import apps
  11. from django.conf import settings, UserSettingsHolder
  12. from django.core import mail
  13. from django.core.signals import request_started
  14. from django.db import reset_queries
  15. from django.http import request
  16. from django.template import Template, loader, TemplateDoesNotExist
  17. from django.template.loaders import cached
  18. from django.test.signals import template_rendered, setting_changed
  19. from django.utils import six
  20. from django.utils.deprecation import RemovedInDjango18Warning, RemovedInDjango19Warning
  21. from django.utils.encoding import force_str
  22. from django.utils.translation import deactivate
  23. __all__ = (
  24. 'Approximate', 'ContextList', 'get_runner',
  25. 'modify_settings', 'override_settings',
  26. 'requires_tz_support',
  27. 'setup_test_environment', 'teardown_test_environment',
  28. )
  29. RESTORE_LOADERS_ATTR = '_original_template_source_loaders'
  30. TZ_SUPPORT = hasattr(time, 'tzset')
  31. class Approximate(object):
  32. def __init__(self, val, places=7):
  33. self.val = val
  34. self.places = places
  35. def __repr__(self):
  36. return repr(self.val)
  37. def __eq__(self, other):
  38. if self.val == other:
  39. return True
  40. return round(abs(self.val - other), self.places) == 0
  41. class ContextList(list):
  42. """A wrapper that provides direct key access to context items contained
  43. in a list of context objects.
  44. """
  45. def __getitem__(self, key):
  46. if isinstance(key, six.string_types):
  47. for subcontext in self:
  48. if key in subcontext:
  49. return subcontext[key]
  50. raise KeyError(key)
  51. else:
  52. return super(ContextList, self).__getitem__(key)
  53. def __contains__(self, key):
  54. try:
  55. self[key]
  56. except KeyError:
  57. return False
  58. return True
  59. def keys(self):
  60. """
  61. Flattened keys of subcontexts.
  62. """
  63. keys = set()
  64. for subcontext in self:
  65. for dict in subcontext:
  66. keys |= set(dict.keys())
  67. return keys
  68. def instrumented_test_render(self, context):
  69. """
  70. An instrumented Template render method, providing a signal
  71. that can be intercepted by the test system Client
  72. """
  73. template_rendered.send(sender=self, template=self, context=context)
  74. return self.nodelist.render(context)
def setup_test_environment():
    """Perform any global pre-test setup. This involves:

        - Installing the instrumented test renderer.
        - Setting the email backend to the locmem email backend.
        - Allowing every host (ALLOWED_HOSTS = ['*']).
        - Resetting ``mail.outbox`` to an empty list.
        - Setting the active locale to match the LANGUAGE_CODE setting
          (done by calling ``deactivate()``).

    The replaced values are stashed as attributes on related objects
    (``Template._original_render``, ``mail._original_email_backend`` and
    ``request._original_allowed_hosts``) so teardown_test_environment() can
    restore them. Calling this twice without a teardown in between overwrites
    those stashes with the already-patched values.
    """
    Template._original_render = Template._render
    Template._render = instrumented_test_render

    # Storing previous values in the settings module itself is problematic.
    # Store them in arbitrary (but related) modules instead. See #20636.
    mail._original_email_backend = settings.EMAIL_BACKEND
    settings.EMAIL_BACKEND = 'django.core.mail.backends.locmem.EmailBackend'

    request._original_allowed_hosts = settings.ALLOWED_HOSTS
    settings.ALLOWED_HOSTS = ['*']

    mail.outbox = []

    deactivate()
def teardown_test_environment():
    """Perform any global post-test teardown. This involves:

        - Restoring the original test renderer.
        - Restoring the original EMAIL_BACKEND setting.
        - Restoring the original ALLOWED_HOSTS setting.
        - Removing ``mail.outbox``.

    Must only run after setup_test_environment(): each stashed original is
    deleted once restored, so a missing stash (e.g. a second teardown) raises
    AttributeError.
    """
    Template._render = Template._original_render
    del Template._original_render

    settings.EMAIL_BACKEND = mail._original_email_backend
    del mail._original_email_backend

    settings.ALLOWED_HOSTS = request._original_allowed_hosts
    del request._original_allowed_hosts

    del mail.outbox
  103. def get_runner(settings, test_runner_class=None):
  104. if not test_runner_class:
  105. test_runner_class = settings.TEST_RUNNER
  106. test_path = test_runner_class.split('.')
  107. # Allow for Python 2.5 relative paths
  108. if len(test_path) > 1:
  109. test_module_name = '.'.join(test_path[:-1])
  110. else:
  111. test_module_name = '.'
  112. test_module = __import__(test_module_name, {}, {}, force_str(test_path[-1]))
  113. test_runner = getattr(test_module, test_path[-1])
  114. return test_runner
  115. def setup_test_template_loader(templates_dict, use_cached_loader=False):
  116. """
  117. Changes Django to only find templates from within a dictionary (where each
  118. key is the template name and each value is the corresponding template
  119. content to return).
  120. Use meth:`restore_template_loaders` to restore the original loaders.
  121. """
  122. if hasattr(loader, RESTORE_LOADERS_ATTR):
  123. raise Exception("loader.%s already exists" % RESTORE_LOADERS_ATTR)
  124. def test_template_loader(template_name, template_dirs=None):
  125. "A custom template loader that loads templates from a dictionary."
  126. try:
  127. return (templates_dict[template_name], "test:%s" % template_name)
  128. except KeyError:
  129. raise TemplateDoesNotExist(template_name)
  130. if use_cached_loader:
  131. template_loader = cached.Loader(('test_template_loader',))
  132. template_loader._cached_loaders = (test_template_loader,)
  133. else:
  134. template_loader = test_template_loader
  135. setattr(loader, RESTORE_LOADERS_ATTR, loader.template_source_loaders)
  136. loader.template_source_loaders = (template_loader,)
  137. return template_loader
  138. def restore_template_loaders():
  139. """
  140. Restores the original template loaders after
  141. :meth:`setup_test_template_loader` has been run.
  142. """
  143. loader.template_source_loaders = getattr(loader, RESTORE_LOADERS_ATTR)
  144. delattr(loader, RESTORE_LOADERS_ATTR)
class override_settings(object):
    """
    Acts as either a decorator, or a context manager. If it's a decorator it
    takes a function and returns a wrapped function. If it's a contextmanager
    it's used with the ``with`` statement. In either event entering/exiting
    are called before and after, respectively, the function/block is executed.

    When applied to a class (which must subclass SimpleTestCase), nothing is
    activated immediately: the overrides are merged into the class's
    ``_overridden_settings`` attribute by save_options() and the class is
    returned unchanged otherwise. (Presumably SimpleTestCase applies them
    around its tests — not visible here.)

    The previous settings object is kept on the instance (``self.wrapped``)
    between enable() and disable(), so one instance must not be enabled again
    while it is already active.
    """
    def __init__(self, **kwargs):
        # Mapping of setting name -> overriding value.
        self.options = kwargs

    def __enter__(self):
        # Returns None, so ``with override_settings(...) as x`` binds None.
        self.enable()

    def __exit__(self, exc_type, exc_value, traceback):
        self.disable()

    def __call__(self, test_func):
        # Local import; presumably avoids a circular import with
        # django.test — confirm.
        from django.test import SimpleTestCase
        if isinstance(test_func, type):
            if not issubclass(test_func, SimpleTestCase):
                raise Exception(
                    "Only subclasses of Django SimpleTestCase can be decorated "
                    "with override_settings")
            self.save_options(test_func)
            return test_func
        else:
            @wraps(test_func)
            def inner(*args, **kwargs):
                with self:
                    return test_func(*args, **kwargs)
            return inner

    def save_options(self, test_func):
        # Expects the class to define ``_overridden_settings`` (None when no
        # overrides have been recorded yet).
        if test_func._overridden_settings is None:
            test_func._overridden_settings = self.options
        else:
            # Duplicate dict to prevent subclasses from altering their parent.
            test_func._overridden_settings = dict(
                test_func._overridden_settings, **self.options)

    def enable(self):
        # Keep this code at the beginning to leave the settings unchanged
        # in case it raises an exception because INSTALLED_APPS is invalid.
        if 'INSTALLED_APPS' in self.options:
            try:
                apps.set_installed_apps(self.options['INSTALLED_APPS'])
            except Exception:
                apps.unset_installed_apps()
                raise
        # Build a holder layered over the current settings object and set the
        # overrides on it.
        override = UserSettingsHolder(settings._wrapped)
        for key, new_value in self.options.items():
            setattr(override, key, new_value)
        self.wrapped = settings._wrapped
        settings._wrapped = override
        # Signals go out only after the swap, so receivers that read
        # ``settings`` already see the new values.
        for key, new_value in self.options.items():
            setting_changed.send(sender=settings._wrapped.__class__,
                                 setting=key, value=new_value, enter=True)

    def disable(self):
        if 'INSTALLED_APPS' in self.options:
            apps.unset_installed_apps()
        settings._wrapped = self.wrapped
        del self.wrapped
        for key in self.options:
            # A setting absent after restoring is reported with value None.
            # The sender here is the restored settings object's class, which
            # may differ from the sender used in enable().
            new_value = getattr(settings, key, None)
            setting_changed.send(sender=settings._wrapped.__class__,
                                 setting=key, value=new_value, enter=False)
  206. class modify_settings(override_settings):
  207. """
  208. Like override_settings, but makes it possible to append, prepend or remove
  209. items instead of redefining the entire list.
  210. """
  211. def __init__(self, *args, **kwargs):
  212. if args:
  213. # Hack used when instantiating from SimpleTestCase._pre_setup.
  214. assert not kwargs
  215. self.operations = args[0]
  216. else:
  217. assert not args
  218. self.operations = list(kwargs.items())
  219. def save_options(self, test_func):
  220. if test_func._modified_settings is None:
  221. test_func._modified_settings = self.operations
  222. else:
  223. # Duplicate list to prevent subclasses from altering their parent.
  224. test_func._modified_settings = list(
  225. test_func._modified_settings) + self.operations
  226. def enable(self):
  227. self.options = {}
  228. for name, operations in self.operations:
  229. try:
  230. # When called from SimpleTestCase._pre_setup, values may be
  231. # overridden several times; cumulate changes.
  232. value = self.options[name]
  233. except KeyError:
  234. value = list(getattr(settings, name, []))
  235. for action, items in operations.items():
  236. # items my be a single value or an iterable.
  237. if isinstance(items, six.string_types):
  238. items = [items]
  239. if action == 'append':
  240. value = value + [item for item in items if item not in value]
  241. elif action == 'prepend':
  242. value = [item for item in items if item not in value] + value
  243. elif action == 'remove':
  244. value = [item for item in value if item not in items]
  245. else:
  246. raise ValueError("Unsupported action: %s" % action)
  247. self.options[name] = value
  248. super(modify_settings, self).enable()
  249. def override_system_checks(new_checks):
  250. """ Acts as a decorator. Overrides list of registered system checks.
  251. Useful when you override `INSTALLED_APPS`, e.g. if you exclude `auth` app,
  252. you also need to exclude its system checks. """
  253. from django.core.checks.registry import registry
  254. def outer(test_func):
  255. @wraps(test_func)
  256. def inner(*args, **kwargs):
  257. old_checks = registry.registered_checks
  258. registry.registered_checks = new_checks
  259. try:
  260. return test_func(*args, **kwargs)
  261. finally:
  262. registry.registered_checks = old_checks
  263. return inner
  264. return outer
  265. def compare_xml(want, got):
  266. """Tries to do a 'xml-comparison' of want and got. Plain string
  267. comparison doesn't always work because, for example, attribute
  268. ordering should not be important. Comment nodes are not considered in the
  269. comparison.
  270. Based on http://codespeak.net/svn/lxml/trunk/src/lxml/doctestcompare.py
  271. """
  272. _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')
  273. def norm_whitespace(v):
  274. return _norm_whitespace_re.sub(' ', v)
  275. def child_text(element):
  276. return ''.join([c.data for c in element.childNodes
  277. if c.nodeType == Node.TEXT_NODE])
  278. def children(element):
  279. return [c for c in element.childNodes
  280. if c.nodeType == Node.ELEMENT_NODE]
  281. def norm_child_text(element):
  282. return norm_whitespace(child_text(element))
  283. def attrs_dict(element):
  284. return dict(element.attributes.items())
  285. def check_element(want_element, got_element):
  286. if want_element.tagName != got_element.tagName:
  287. return False
  288. if norm_child_text(want_element) != norm_child_text(got_element):
  289. return False
  290. if attrs_dict(want_element) != attrs_dict(got_element):
  291. return False
  292. want_children = children(want_element)
  293. got_children = children(got_element)
  294. if len(want_children) != len(got_children):
  295. return False
  296. for want, got in zip(want_children, got_children):
  297. if not check_element(want, got):
  298. return False
  299. return True
  300. def first_node(document):
  301. for node in document.childNodes:
  302. if node.nodeType != Node.COMMENT_NODE:
  303. return node
  304. want, got = strip_quotes(want, got)
  305. want = want.replace('\\n', '\n')
  306. got = got.replace('\\n', '\n')
  307. # If the string is not a complete xml document, we may need to add a
  308. # root element. This allow us to compare fragments, like "<foo/><bar/>"
  309. if not want.startswith('<?xml'):
  310. wrapper = '<root>%s</root>'
  311. want = wrapper % want
  312. got = wrapper % got
  313. # Parse the want and got strings, and compare the parsings.
  314. want_root = first_node(parseString(want))
  315. got_root = first_node(parseString(got))
  316. return check_element(want_root, got_root)
  317. def strip_quotes(want, got):
  318. """
  319. Strip quotes of doctests output values:
  320. >>> strip_quotes("'foo'")
  321. "foo"
  322. >>> strip_quotes('"foo"')
  323. "foo"
  324. """
  325. def is_quoted_string(s):
  326. s = s.strip()
  327. return (len(s) >= 2
  328. and s[0] == s[-1]
  329. and s[0] in ('"', "'"))
  330. def is_quoted_unicode(s):
  331. s = s.strip()
  332. return (len(s) >= 3
  333. and s[0] == 'u'
  334. and s[1] == s[-1]
  335. and s[1] in ('"', "'"))
  336. if is_quoted_string(want) and is_quoted_string(got):
  337. want = want.strip()[1:-1]
  338. got = got.strip()[1:-1]
  339. elif is_quoted_unicode(want) and is_quoted_unicode(got):
  340. want = want.strip()[2:-1]
  341. got = got.strip()[2:-1]
  342. return want, got
  343. def str_prefix(s):
  344. return s % {'_': '' if six.PY3 else 'u'}
  345. class CaptureQueriesContext(object):
  346. """
  347. Context manager that captures queries executed by the specified connection.
  348. """
  349. def __init__(self, connection):
  350. self.connection = connection
  351. def __iter__(self):
  352. return iter(self.captured_queries)
  353. def __getitem__(self, index):
  354. return self.captured_queries[index]
  355. def __len__(self):
  356. return len(self.captured_queries)
  357. @property
  358. def captured_queries(self):
  359. return self.connection.queries[self.initial_queries:self.final_queries]
  360. def __enter__(self):
  361. self.use_debug_cursor = self.connection.use_debug_cursor
  362. self.connection.use_debug_cursor = True
  363. self.initial_queries = len(self.connection.queries)
  364. self.final_queries = None
  365. request_started.disconnect(reset_queries)
  366. return self
  367. def __exit__(self, exc_type, exc_value, traceback):
  368. self.connection.use_debug_cursor = self.use_debug_cursor
  369. request_started.connect(reset_queries)
  370. if exc_type is not None:
  371. return
  372. self.final_queries = len(self.connection.queries)
  373. class IgnoreDeprecationWarningsMixin(object):
  374. warning_classes = [RemovedInDjango18Warning]
  375. def setUp(self):
  376. super(IgnoreDeprecationWarningsMixin, self).setUp()
  377. self.catch_warnings = warnings.catch_warnings()
  378. self.catch_warnings.__enter__()
  379. for warning_class in self.warning_classes:
  380. warnings.filterwarnings("ignore", category=warning_class)
  381. def tearDown(self):
  382. self.catch_warnings.__exit__(*sys.exc_info())
  383. super(IgnoreDeprecationWarningsMixin, self).tearDown()
  384. class IgnorePendingDeprecationWarningsMixin(IgnoreDeprecationWarningsMixin):
  385. warning_classes = [RemovedInDjango19Warning]
  386. class IgnoreAllDeprecationWarningsMixin(IgnoreDeprecationWarningsMixin):
  387. warning_classes = [RemovedInDjango19Warning, RemovedInDjango18Warning]
  388. @contextmanager
  389. def patch_logger(logger_name, log_level):
  390. """
  391. Context manager that takes a named logger and the logging level
  392. and provides a simple mock-like list of messages received
  393. """
  394. calls = []
  395. def replacement(msg, *args, **kwargs):
  396. calls.append(msg % args)
  397. logger = logging.getLogger(logger_name)
  398. orig = getattr(logger, log_level)
  399. setattr(logger, log_level, replacement)
  400. try:
  401. yield calls
  402. finally:
  403. setattr(logger, log_level, orig)
  404. # On OSes that don't provide tzset (Windows), we can't set the timezone
  405. # in which the program runs. As a consequence, we must skip tests that
  406. # don't enforce a specific timezone (with timezone.override or equivalent),
  407. # or attempt to interpret naive datetimes in the default timezone.
  408. requires_tz_support = skipUnless(TZ_SUPPORT,
  409. "This test relies on the ability to run a program in an arbitrary "
  410. "time zone, but your operating system isn't able to do that.")
  411. @contextmanager
  412. def extend_sys_path(*paths):
  413. """Context manager to temporarily add paths to sys.path."""
  414. _orig_sys_path = sys.path[:]
  415. sys.path.extend(paths)
  416. try:
  417. yield
  418. finally:
  419. sys.path = _orig_sys_path