# case.py (21 KB)


  1. from __future__ import absolute_import
  2. try:
  3. import unittest # noqa
  4. unittest.skip
  5. from unittest.util import safe_repr, unorderable_list_difference
  6. except AttributeError:
  7. import unittest2 as unittest # noqa
  8. from unittest2.util import safe_repr, unorderable_list_difference # noqa
  9. import importlib
  10. import inspect
  11. import logging
  12. import os
  13. import platform
  14. import re
  15. import sys
  16. import threading
  17. import time
  18. import warnings
  19. from contextlib import contextmanager
  20. from copy import deepcopy
  21. from datetime import datetime, timedelta
  22. from functools import partial, wraps
  23. from types import ModuleType
  24. try:
  25. from unittest import mock
  26. except ImportError:
  27. import mock # noqa
  28. from nose import SkipTest
  29. from kombu import Queue
  30. from kombu.log import NullHandler
  31. from kombu.utils import nested, symbol_by_name
  32. from celery import Celery
  33. from celery.app import current_app
  34. from celery.backends.cache import CacheBackend, DummyClient
  35. from celery.five import (
  36. WhateverIO, builtins, items, reraise,
  37. string_t, values, open_fqdn,
  38. )
  39. from celery.utils.functional import noop
  40. from celery.utils.imports import qualname
  41. __all__ = [
  42. 'Case', 'AppCase', 'Mock', 'patch', 'call', 'skip_unless_module',
  43. 'wrap_logger', 'with_environ', 'sleepdeprived',
  44. 'skip_if_environ', 'todo', 'skip', 'skip_if',
  45. 'skip_unless', 'mask_modules', 'override_stdouts', 'mock_module',
  46. 'replace_module_value', 'sys_platform', 'reset_modules',
  47. 'patch_modules', 'mock_context', 'mock_open', 'patch_many',
  48. 'assert_signal_called', 'skip_if_pypy',
  49. 'skip_if_jython', 'body_from_sig', 'restore_logging',
  50. ]
  51. patch = mock.patch
  52. call = mock.call
  53. CASE_REDEFINES_SETUP = """\
  54. {name} (subclass of AppCase) redefines private "setUp", should be: "setup"\
  55. """
  56. CASE_REDEFINES_TEARDOWN = """\
  57. {name} (subclass of AppCase) redefines private "tearDown", \
  58. should be: "teardown"\
  59. """
  60. CASE_LOG_REDIRECT_EFFECT = """\
  61. Test {0} did not disable LoggingProxy for {1}\
  62. """
  63. CASE_LOG_LEVEL_EFFECT = """\
  64. Test {0} Modified the level of the root logger\
  65. """
  66. CASE_LOG_HANDLER_EFFECT = """\
  67. Test {0} Modified handlers for the root logger\
  68. """
  69. CELERY_TEST_CONFIG = {
  70. #: Don't want log output when running suite.
  71. 'CELERYD_HIJACK_ROOT_LOGGER': False,
  72. 'CELERY_SEND_TASK_ERROR_EMAILS': False,
  73. 'CELERY_DEFAULT_QUEUE': 'testcelery',
  74. 'CELERY_DEFAULT_EXCHANGE': 'testcelery',
  75. 'CELERY_DEFAULT_ROUTING_KEY': 'testcelery',
  76. 'CELERY_QUEUES': (
  77. Queue('testcelery', routing_key='testcelery'),
  78. ),
  79. 'CELERY_ENABLE_UTC': True,
  80. 'CELERY_TIMEZONE': 'UTC',
  81. 'CELERYD_LOG_COLOR': False,
  82. # Mongo results tests (only executed if installed and running)
  83. 'CELERY_MONGODB_BACKEND_SETTINGS': {
  84. 'host': os.environ.get('MONGO_HOST') or 'localhost',
  85. 'port': os.environ.get('MONGO_PORT') or 27017,
  86. 'database': os.environ.get('MONGO_DB') or 'celery_unittests',
  87. 'taskmeta_collection': (os.environ.get('MONGO_TASKMETA_COLLECTION')
  88. or 'taskmeta_collection'),
  89. 'user': os.environ.get('MONGO_USER'),
  90. 'password': os.environ.get('MONGO_PASSWORD'),
  91. }
  92. }
  93. class Trap(object):
  94. def __getattr__(self, name):
  95. raise RuntimeError('Test depends on current_app')
  96. class UnitLogging(symbol_by_name(Celery.log_cls)):
  97. def __init__(self, *args, **kwargs):
  98. super(UnitLogging, self).__init__(*args, **kwargs)
  99. self.already_setup = True
  100. def UnitApp(name=None, broker=None, backend=None,
  101. set_as_current=False, log=UnitLogging, **kwargs):
  102. app = Celery(name or 'celery.tests',
  103. broker=broker or 'memory://',
  104. backend=backend or 'cache+memory://',
  105. set_as_current=set_as_current,
  106. log=log,
  107. **kwargs)
  108. app.add_defaults(deepcopy(CELERY_TEST_CONFIG))
  109. return app
  110. class Mock(mock.Mock):
  111. def __init__(self, *args, **kwargs):
  112. attrs = kwargs.pop('attrs', None) or {}
  113. super(Mock, self).__init__(*args, **kwargs)
  114. for attr_name, attr_value in items(attrs):
  115. setattr(self, attr_name, attr_value)
  116. def skip_unless_module(module):
  117. def _inner(fun):
  118. @wraps(fun)
  119. def __inner(*args, **kwargs):
  120. try:
  121. importlib.import_module(module)
  122. except ImportError:
  123. raise SkipTest('Does not have %s' % (module, ))
  124. return fun(*args, **kwargs)
  125. return __inner
  126. return _inner
  127. # -- adds assertWarns from recent unittest2, not in Python 2.7.
  128. class _AssertRaisesBaseContext(object):
  129. def __init__(self, expected, test_case, callable_obj=None,
  130. expected_regex=None):
  131. self.expected = expected
  132. self.failureException = test_case.failureException
  133. self.obj_name = None
  134. if isinstance(expected_regex, string_t):
  135. expected_regex = re.compile(expected_regex)
  136. self.expected_regex = expected_regex
  137. class _AssertWarnsContext(_AssertRaisesBaseContext):
  138. """A context manager used to implement TestCase.assertWarns* methods."""
  139. def __enter__(self):
  140. # The __warningregistry__'s need to be in a pristine state for tests
  141. # to work properly.
  142. warnings.resetwarnings()
  143. for v in list(values(sys.modules)):
  144. if getattr(v, '__warningregistry__', None):
  145. v.__warningregistry__ = {}
  146. self.warnings_manager = warnings.catch_warnings(record=True)
  147. self.warnings = self.warnings_manager.__enter__()
  148. warnings.simplefilter('always', self.expected)
  149. return self
  150. def __exit__(self, exc_type, exc_value, tb):
  151. self.warnings_manager.__exit__(exc_type, exc_value, tb)
  152. if exc_type is not None:
  153. # let unexpected exceptions pass through
  154. return
  155. try:
  156. exc_name = self.expected.__name__
  157. except AttributeError:
  158. exc_name = str(self.expected)
  159. first_matching = None
  160. for m in self.warnings:
  161. w = m.message
  162. if not isinstance(w, self.expected):
  163. continue
  164. if first_matching is None:
  165. first_matching = w
  166. if (self.expected_regex is not None and
  167. not self.expected_regex.search(str(w))):
  168. continue
  169. # store warning for later retrieval
  170. self.warning = w
  171. self.filename = m.filename
  172. self.lineno = m.lineno
  173. return
  174. # Now we simply try to choose a helpful failure message
  175. if first_matching is not None:
  176. raise self.failureException(
  177. '%r does not match %r' % (
  178. self.expected_regex.pattern, str(first_matching)))
  179. if self.obj_name:
  180. raise self.failureException(
  181. '%s not triggered by %s' % (exc_name, self.obj_name))
  182. else:
  183. raise self.failureException('%s not triggered' % exc_name)
  184. class Case(unittest.TestCase):
  185. def assertWarns(self, expected_warning):
  186. return _AssertWarnsContext(expected_warning, self, None)
  187. def assertWarnsRegex(self, expected_warning, expected_regex):
  188. return _AssertWarnsContext(expected_warning, self,
  189. None, expected_regex)
  190. def assertDictContainsSubset(self, expected, actual, msg=None):
  191. missing, mismatched = [], []
  192. for key, value in items(expected):
  193. if key not in actual:
  194. missing.append(key)
  195. elif value != actual[key]:
  196. mismatched.append('%s, expected: %s, actual: %s' % (
  197. safe_repr(key), safe_repr(value),
  198. safe_repr(actual[key])))
  199. if not (missing or mismatched):
  200. return
  201. standard_msg = ''
  202. if missing:
  203. standard_msg = 'Missing: %s' % ','.join(map(safe_repr, missing))
  204. if mismatched:
  205. if standard_msg:
  206. standard_msg += '; '
  207. standard_msg += 'Mismatched values: %s' % (
  208. ','.join(mismatched))
  209. self.fail(self._formatMessage(msg, standard_msg))
  210. def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
  211. missing = unexpected = None
  212. try:
  213. expected = sorted(expected_seq)
  214. actual = sorted(actual_seq)
  215. except TypeError:
  216. # Unsortable items (example: set(), complex(), ...)
  217. expected = list(expected_seq)
  218. actual = list(actual_seq)
  219. missing, unexpected = unorderable_list_difference(
  220. expected, actual)
  221. else:
  222. return self.assertSequenceEqual(expected, actual, msg=msg)
  223. errors = []
  224. if missing:
  225. errors.append(
  226. 'Expected, but missing:\n %s' % (safe_repr(missing), )
  227. )
  228. if unexpected:
  229. errors.append(
  230. 'Unexpected, but present:\n %s' % (safe_repr(unexpected), )
  231. )
  232. if errors:
  233. standardMsg = '\n'.join(errors)
  234. self.fail(self._formatMessage(msg, standardMsg))
  235. def depends_on_current_app(fun):
  236. if inspect.isclass(fun):
  237. fun.contained = False
  238. else:
  239. @wraps(fun)
  240. def __inner(self, *args, **kwargs):
  241. self.app.set_current()
  242. return fun(self, *args, **kwargs)
  243. return __inner
  244. class AppCase(Case):
  245. contained = True
  246. def __init__(self, *args, **kwargs):
  247. super(AppCase, self).__init__(*args, **kwargs)
  248. if self.__class__.__dict__.get('setUp'):
  249. raise RuntimeError(
  250. CASE_REDEFINES_SETUP.format(name=qualname(self)),
  251. )
  252. if self.__class__.__dict__.get('tearDown'):
  253. raise RuntimeError(
  254. CASE_REDEFINES_TEARDOWN.format(name=qualname(self)),
  255. )
  256. def Celery(self, *args, **kwargs):
  257. return UnitApp(*args, **kwargs)
  258. def setUp(self):
  259. self._threads_at_setup = list(threading.enumerate())
  260. from celery import _state
  261. self._current_app = current_app()
  262. self._default_app = _state.default_app
  263. trap = Trap()
  264. _state.set_default_app(trap)
  265. _state._tls.current_app = trap
  266. self.app = self.Celery(set_as_current=False)
  267. if not self.contained:
  268. self.app.set_current()
  269. root = logging.getLogger()
  270. self.__rootlevel = root.level
  271. self.__roothandlers = root.handlers
  272. try:
  273. self.setup()
  274. except:
  275. self._teardown_app()
  276. raise
  277. def _teardown_app(self):
  278. from celery.utils.log import LoggingProxy
  279. assert sys.stdout
  280. assert sys.stderr
  281. assert sys.__stdout__
  282. assert sys.__stderr__
  283. this = self._get_test_name()
  284. if isinstance(sys.stdout, LoggingProxy) or \
  285. isinstance(sys.__stdout__, LoggingProxy):
  286. raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stdout'))
  287. if isinstance(sys.stderr, LoggingProxy) or \
  288. isinstance(sys.__stderr__, LoggingProxy):
  289. raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stderr'))
  290. backend = self.app.__dict__.get('backend')
  291. if backend is not None:
  292. if isinstance(backend, CacheBackend):
  293. if isinstance(backend.client, DummyClient):
  294. backend.client.cache.clear()
  295. backend._cache.clear()
  296. from celery._state import _tls, set_default_app
  297. set_default_app(self._default_app)
  298. _tls.current_app = self._current_app
  299. if self.app is not self._current_app:
  300. self.app.close()
  301. self.app = None
  302. self.assertEqual(
  303. self._threads_at_setup, list(threading.enumerate()),
  304. )
  305. def _get_test_name(self):
  306. return '.'.join([self.__class__.__name__, self._testMethodName])
  307. def tearDown(self):
  308. try:
  309. self.teardown()
  310. finally:
  311. self._teardown_app()
  312. self.assert_no_logging_side_effect()
  313. def assert_no_logging_side_effect(self):
  314. this = self._get_test_name()
  315. root = logging.getLogger()
  316. if root.level != self.__rootlevel:
  317. raise RuntimeError(CASE_LOG_LEVEL_EFFECT.format(this))
  318. if root.handlers != self.__roothandlers:
  319. raise RuntimeError(CASE_LOG_HANDLER_EFFECT.format(this))
  320. def setup(self):
  321. pass
  322. def teardown(self):
  323. pass
  324. def get_handlers(logger):
  325. return [h for h in logger.handlers if not isinstance(h, NullHandler)]
  326. @contextmanager
  327. def wrap_logger(logger, loglevel=logging.ERROR):
  328. old_handlers = get_handlers(logger)
  329. sio = WhateverIO()
  330. siohandler = logging.StreamHandler(sio)
  331. logger.handlers = [siohandler]
  332. try:
  333. yield sio
  334. finally:
  335. logger.handlers = old_handlers
  336. def with_environ(env_name, env_value):
  337. def _envpatched(fun):
  338. @wraps(fun)
  339. def _patch_environ(*args, **kwargs):
  340. prev_val = os.environ.get(env_name)
  341. os.environ[env_name] = env_value
  342. try:
  343. return fun(*args, **kwargs)
  344. finally:
  345. os.environ[env_name] = prev_val or ''
  346. return _patch_environ
  347. return _envpatched
  348. def sleepdeprived(module=time):
  349. def _sleepdeprived(fun):
  350. @wraps(fun)
  351. def __sleepdeprived(*args, **kwargs):
  352. old_sleep = module.sleep
  353. module.sleep = noop
  354. try:
  355. return fun(*args, **kwargs)
  356. finally:
  357. module.sleep = old_sleep
  358. return __sleepdeprived
  359. return _sleepdeprived
  360. def skip_if_environ(env_var_name):
  361. def _wrap_test(fun):
  362. @wraps(fun)
  363. def _skips_if_environ(*args, **kwargs):
  364. if os.environ.get(env_var_name):
  365. raise SkipTest('SKIP %s: %s set\n' % (
  366. fun.__name__, env_var_name))
  367. return fun(*args, **kwargs)
  368. return _skips_if_environ
  369. return _wrap_test
  370. def _skip_test(reason, sign):
  371. def _wrap_test(fun):
  372. @wraps(fun)
  373. def _skipped_test(*args, **kwargs):
  374. raise SkipTest('%s: %s' % (sign, reason))
  375. return _skipped_test
  376. return _wrap_test
  377. def todo(reason):
  378. """TODO test decorator."""
  379. return _skip_test(reason, 'TODO')
  380. def skip(reason):
  381. """Skip test decorator."""
  382. return _skip_test(reason, 'SKIP')
  383. def skip_if(predicate, reason):
  384. """Skip test if predicate is :const:`True`."""
  385. def _inner(fun):
  386. return predicate and skip(reason)(fun) or fun
  387. return _inner
  388. def skip_unless(predicate, reason):
  389. """Skip test if predicate is :const:`False`."""
  390. return skip_if(not predicate, reason)
  391. # Taken from
  392. # http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py
  393. @contextmanager
  394. def mask_modules(*modnames):
  395. """Ban some modules from being importable inside the context
  396. For example:
  397. >>> with missing_modules('sys'):
  398. ... try:
  399. ... import sys
  400. ... except ImportError:
  401. ... print 'sys not found'
  402. sys not found
  403. >>> import sys
  404. >>> sys.version
  405. (2, 5, 2, 'final', 0)
  406. """
  407. realimport = builtins.__import__
  408. def myimp(name, *args, **kwargs):
  409. if name in modnames:
  410. raise ImportError('No module named %s' % name)
  411. else:
  412. return realimport(name, *args, **kwargs)
  413. builtins.__import__ = myimp
  414. try:
  415. yield True
  416. finally:
  417. builtins.__import__ = realimport
  418. @contextmanager
  419. def override_stdouts():
  420. """Override `sys.stdout` and `sys.stderr` with `WhateverIO`."""
  421. prev_out, prev_err = sys.stdout, sys.stderr
  422. mystdout, mystderr = WhateverIO(), WhateverIO()
  423. sys.stdout = sys.__stdout__ = mystdout
  424. sys.stderr = sys.__stderr__ = mystderr
  425. try:
  426. yield mystdout, mystderr
  427. finally:
  428. sys.stdout = sys.__stdout__ = prev_out
  429. sys.stderr = sys.__stderr__ = prev_err
  430. def _old_patch(module, name, mocked):
  431. module = importlib.import_module(module)
  432. def _patch(fun):
  433. @wraps(fun)
  434. def __patched(*args, **kwargs):
  435. prev = getattr(module, name)
  436. setattr(module, name, mocked)
  437. try:
  438. return fun(*args, **kwargs)
  439. finally:
  440. setattr(module, name, prev)
  441. return __patched
  442. return _patch
  443. @contextmanager
  444. def replace_module_value(module, name, value=None):
  445. has_prev = hasattr(module, name)
  446. prev = getattr(module, name, None)
  447. if value:
  448. setattr(module, name, value)
  449. else:
  450. try:
  451. delattr(module, name)
  452. except AttributeError:
  453. pass
  454. try:
  455. yield
  456. finally:
  457. if prev is not None:
  458. setattr(sys, name, prev)
  459. if not has_prev:
  460. try:
  461. delattr(module, name)
  462. except AttributeError:
  463. pass
  464. pypy_version = partial(
  465. replace_module_value, sys, 'pypy_version_info',
  466. )
  467. platform_pyimp = partial(
  468. replace_module_value, platform, 'python_implementation',
  469. )
  470. @contextmanager
  471. def sys_platform(value):
  472. prev, sys.platform = sys.platform, value
  473. try:
  474. yield
  475. finally:
  476. sys.platform = prev
  477. @contextmanager
  478. def reset_modules(*modules):
  479. prev = dict((k, sys.modules.pop(k)) for k in modules if k in sys.modules)
  480. try:
  481. yield
  482. finally:
  483. sys.modules.update(prev)
  484. @contextmanager
  485. def patch_modules(*modules):
  486. prev = {}
  487. for mod in modules:
  488. prev[mod] = sys.modules.get(mod)
  489. sys.modules[mod] = ModuleType(mod)
  490. try:
  491. yield
  492. finally:
  493. for name, mod in items(prev):
  494. if mod is None:
  495. sys.modules.pop(name, None)
  496. else:
  497. sys.modules[name] = mod
  498. @contextmanager
  499. def mock_module(*names):
  500. prev = {}
  501. class MockModule(ModuleType):
  502. def __getattr__(self, attr):
  503. setattr(self, attr, Mock())
  504. return ModuleType.__getattribute__(self, attr)
  505. mods = []
  506. for name in names:
  507. try:
  508. prev[name] = sys.modules[name]
  509. except KeyError:
  510. pass
  511. mod = sys.modules[name] = MockModule(name)
  512. mods.append(mod)
  513. try:
  514. yield mods
  515. finally:
  516. for name in names:
  517. try:
  518. sys.modules[name] = prev[name]
  519. except KeyError:
  520. try:
  521. del(sys.modules[name])
  522. except KeyError:
  523. pass
  524. @contextmanager
  525. def mock_context(mock, typ=Mock):
  526. context = mock.return_value = Mock()
  527. context.__enter__ = typ()
  528. context.__exit__ = typ()
  529. def on_exit(*x):
  530. if x[0]:
  531. reraise(x[0], x[1], x[2])
  532. context.__exit__.side_effect = on_exit
  533. context.__enter__.return_value = context
  534. try:
  535. yield context
  536. finally:
  537. context.reset()
  538. @contextmanager
  539. def mock_open(typ=WhateverIO, side_effect=None):
  540. with patch(open_fqdn) as open_:
  541. with mock_context(open_) as context:
  542. if side_effect is not None:
  543. context.__enter__.side_effect = side_effect
  544. val = context.__enter__.return_value = typ()
  545. val.__exit__ = Mock()
  546. yield val
  547. def patch_many(*targets):
  548. return nested(*[patch(target) for target in targets])
  549. @contextmanager
  550. def assert_signal_called(signal, **expected):
  551. handler = Mock()
  552. call_handler = partial(handler)
  553. signal.connect(call_handler)
  554. try:
  555. yield handler
  556. finally:
  557. signal.disconnect(call_handler)
  558. handler.assert_called_with(signal=signal, **expected)
  559. def skip_if_pypy(fun):
  560. @wraps(fun)
  561. def _inner(*args, **kwargs):
  562. if getattr(sys, 'pypy_version_info', None):
  563. raise SkipTest('does not work on PyPy')
  564. return fun(*args, **kwargs)
  565. return _inner
  566. def skip_if_jython(fun):
  567. @wraps(fun)
  568. def _inner(*args, **kwargs):
  569. if sys.platform.startswith('java'):
  570. raise SkipTest('does not work on Jython')
  571. return fun(*args, **kwargs)
  572. return _inner
  573. def body_from_sig(app, sig, utc=True):
  574. sig.freeze()
  575. callbacks = sig.options.pop('link', None)
  576. errbacks = sig.options.pop('link_error', None)
  577. countdown = sig.options.pop('countdown', None)
  578. if countdown:
  579. eta = app.now() + timedelta(seconds=countdown)
  580. else:
  581. eta = sig.options.pop('eta', None)
  582. if eta and isinstance(eta, datetime):
  583. eta = eta.isoformat()
  584. expires = sig.options.pop('expires', None)
  585. if expires and isinstance(expires, int):
  586. expires = app.now() + timedelta(seconds=expires)
  587. if expires and isinstance(expires, datetime):
  588. expires = expires.isoformat()
  589. return {
  590. 'task': sig.task,
  591. 'id': sig.id,
  592. 'args': sig.args,
  593. 'kwargs': sig.kwargs,
  594. 'callbacks': [dict(s) for s in callbacks] if callbacks else None,
  595. 'errbacks': [dict(s) for s in errbacks] if errbacks else None,
  596. 'eta': eta,
  597. 'utc': utc,
  598. 'expires': expires,
  599. }
  600. @contextmanager
  601. def restore_logging():
  602. outs = sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__
  603. root = logging.getLogger()
  604. level = root.level
  605. handlers = root.handlers
  606. try:
  607. yield
  608. finally:
  609. sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__ = outs
  610. root.level = level
  611. root.handlers[:] = handlers