case.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. from __future__ import absolute_import
  2. import atexit
  3. import logging
  4. import os
  5. import signal
  6. import socket
  7. import sys
  8. import traceback
  9. from itertools import count
  10. from time import time
  11. from celery import current_app
  12. from celery.exceptions import TimeoutError
  13. from celery.app.control import flatten_reply
  14. from celery.utils.imports import qualname
  15. from celery.tests.case import Case
# Host name of the machine running the tests; WorkerCase uses it as the
# default base node name for the worker it starts.
HOSTNAME = socket.gethostname()
  17. def say(msg):
  18. sys.stderr.write('%s\n' % msg)
  19. def try_while(fun, reason='Timed out', timeout=10, interval=0.5):
  20. time_start = time()
  21. for iterations in count(0):
  22. if time() - time_start >= timeout:
  23. raise TimeoutError()
  24. ret = fun()
  25. if ret:
  26. return ret
  27. class Worker(object):
  28. started = False
  29. worker_ids = count(1)
  30. _shutdown_called = False
  31. def __init__(self, hostname, loglevel='error', app=None):
  32. self.hostname = hostname
  33. self.loglevel = loglevel
  34. self.app = app or current_app._get_current_object()
  35. def start(self):
  36. if not self.started:
  37. self._fork_and_exec()
  38. self.started = True
  39. def _fork_and_exec(self):
  40. pid = os.fork()
  41. if pid == 0:
  42. self.app.worker_main(['worker', '--loglevel=INFO',
  43. '-n', self.hostname,
  44. '-P', 'solo'])
  45. os._exit(0)
  46. self.pid = pid
  47. def ping(self, *args, **kwargs):
  48. return self.app.control.ping(*args, **kwargs)
  49. def is_alive(self, timeout=1):
  50. r = self.ping(destination=[self.hostname], timeout=timeout)
  51. return self.hostname in flatten_reply(r)
  52. def wait_until_started(self, timeout=10, interval=0.5):
  53. try_while(
  54. lambda: self.is_alive(interval),
  55. "Worker won't start (after %s secs.)" % timeout,
  56. interval=interval, timeout=timeout,
  57. )
  58. say('--WORKER %s IS ONLINE--' % self.hostname)
  59. def ensure_shutdown(self, timeout=10, interval=0.5):
  60. os.kill(self.pid, signal.SIGTERM)
  61. try_while(
  62. lambda: not self.is_alive(interval),
  63. "Worker won't shutdown (after %s secs.)" % timeout,
  64. timeout=10, interval=0.5,
  65. )
  66. say('--WORKER %s IS SHUTDOWN--' % self.hostname)
  67. self._shutdown_called = True
  68. def ensure_started(self):
  69. self.start()
  70. self.wait_until_started()
  71. @classmethod
  72. def managed(cls, hostname=None, caller=None):
  73. hostname = hostname or socket.gethostname()
  74. if caller:
  75. hostname = '.'.join([qualname(caller), hostname])
  76. else:
  77. hostname += str(next(cls.worker_ids()))
  78. worker = cls(hostname)
  79. worker.ensure_started()
  80. stack = traceback.format_stack()
  81. @atexit.register
  82. def _ensure_shutdown_once():
  83. if not worker._shutdown_called:
  84. say('-- Found worker not stopped at shutdown: %s\n%s' % (
  85. worker.hostname,
  86. '\n'.join(stack)))
  87. worker.ensure_shutdown()
  88. return worker
  89. class WorkerCase(Case):
  90. hostname = HOSTNAME
  91. worker = None
  92. @classmethod
  93. def setUpClass(cls):
  94. logging.getLogger('amqp').setLevel(logging.ERROR)
  95. cls.worker = Worker.managed(cls.hostname, caller=cls)
  96. @classmethod
  97. def tearDownClass(cls):
  98. cls.worker.ensure_shutdown()
  99. def assertWorkerAlive(self, timeout=1):
  100. self.assertTrue(self.worker.is_alive)
  101. def inspect(self, timeout=1):
  102. return self.app.control.inspect([self.worker.hostname],
  103. timeout=timeout)
  104. def my_response(self, response):
  105. return flatten_reply(response)[self.worker.hostname]
  106. def is_accepted(self, task_id, interval=0.5):
  107. active = self.inspect(timeout=interval).active()
  108. if active:
  109. for task in active[self.worker.hostname]:
  110. if task['id'] == task_id:
  111. return True
  112. return False
  113. def is_reserved(self, task_id, interval=0.5):
  114. reserved = self.inspect(timeout=interval).reserved()
  115. if reserved:
  116. for task in reserved[self.worker.hostname]:
  117. if task['id'] == task_id:
  118. return True
  119. return False
  120. def is_scheduled(self, task_id, interval=0.5):
  121. schedule = self.inspect(timeout=interval).scheduled()
  122. if schedule:
  123. for item in schedule[self.worker.hostname]:
  124. if item['request']['id'] == task_id:
  125. return True
  126. return False
  127. def is_received(self, task_id, interval=0.5):
  128. return (self.is_reserved(task_id, interval) or
  129. self.is_scheduled(task_id, interval) or
  130. self.is_accepted(task_id, interval))
  131. def ensure_accepted(self, task_id, interval=0.5, timeout=10):
  132. return try_while(lambda: self.is_accepted(task_id, interval),
  133. 'Task not accepted within timeout',
  134. interval=0.5, timeout=10)
  135. def ensure_received(self, task_id, interval=0.5, timeout=10):
  136. return try_while(lambda: self.is_received(task_id, interval),
  137. 'Task not receied within timeout',
  138. interval=0.5, timeout=10)
  139. def ensure_scheduled(self, task_id, interval=0.5, timeout=10):
  140. return try_while(lambda: self.is_scheduled(task_id, interval),
  141. 'Task not scheduled within timeout',
  142. interval=0.5, timeout=10)