"""Test helpers that fork a worker process and poll it via remote control."""
from __future__ import absolute_import

import atexit
import logging
import os
import signal
import socket
import sys
import traceback

from itertools import count
from time import sleep, time

from celery import current_app
from celery.exceptions import TimeoutError
from celery.app.control import flatten_reply
from celery.utils.imports import qualname

from celery.tests.case import Case

HOSTNAME = socket.gethostname()


def say(msg):
    """Write *msg* followed by a newline to standard error."""
    sys.stderr.write('%s\n' % msg)


def try_while(fun, reason='Timed out', timeout=10, interval=0.5):
    """Call *fun* repeatedly until it returns a truthy value.

    :param fun: zero-argument callable to poll.
    :param reason: message attached to the timeout error.
    :param timeout: maximum number of seconds to keep polling.
    :param interval: minimum number of seconds between the start of two
        consecutive attempts.

    :returns: the first truthy value returned by *fun*.
    :raises TimeoutError: if *timeout* seconds elapse before *fun* succeeds.
        The deadline is checked before each attempt, so ``timeout=0`` raises
        without calling *fun* at all.

    """
    time_start = time()
    while True:
        attempt_start = time()
        if attempt_start - time_start >= timeout:
            raise TimeoutError(reason)
        ret = fun()
        if ret:
            return ret
        # Only sleep for whatever part of the interval ``fun`` did not
        # already consume (many callers block inside ``fun`` for up to
        # ``interval`` seconds), and never sleep past the deadline.
        now = time()
        pause = min(interval - (now - attempt_start),
                    timeout - (now - time_start))
        if pause > 0:
            sleep(pause)


class Worker(object):
    """Handle for a worker running in a forked child process."""

    started = False
    # Shared counter used by :meth:`managed` to make hostnames unique.
    worker_ids = count(1)
    _shutdown_called = False

    def __init__(self, hostname, loglevel='error', app=None):
        """Store the worker's node name, log level and app.

        When *app* is not given the current app is used.

        """
        self.hostname = hostname
        self.loglevel = loglevel
        self.app = app or current_app._get_current_object()

    def start(self):
        """Fork the worker process unless it has already been started."""
        if not self.started:
            self._fork_and_exec()
            self.started = True

    def _fork_and_exec(self):
        """Fork; the child runs the worker, the parent records the pid."""
        pid = os.fork()
        if pid == 0:
            # Child process: run the worker and exit without returning
            # into the caller's code.
            # NOTE(review): the log level is hard-coded to INFO here and
            # ``self.loglevel`` is not used -- confirm whether intended.
            self.app.worker_main(['worker', '--loglevel=INFO',
                                  '-n', self.hostname,
                                  '-P', 'solo'])
            os._exit(0)
        self.pid = pid

    def ping(self, *args, **kwargs):
        """Forward all arguments to ``self.app.control.ping``."""
        return self.app.control.ping(*args, **kwargs)

    def is_alive(self, timeout=1):
        """Return True if this worker's hostname is in the ping replies."""
        r = self.ping(destination=[self.hostname], timeout=timeout)
        return self.hostname in flatten_reply(r)

    def wait_until_started(self, timeout=10, interval=0.5):
        """Poll until the worker answers a ping.

        :raises TimeoutError: if it does not answer within *timeout* seconds.

        """
        try_while(
            lambda: self.is_alive(interval),
            "Worker won't start (after %s secs.)" % timeout,
            interval=interval, timeout=timeout,
        )
        say('--WORKER %s IS ONLINE--' % self.hostname)

    def ensure_shutdown(self, timeout=10, interval=0.5):
        """Send SIGTERM to the worker and poll until pings go unanswered.

        :raises TimeoutError: if the worker still answers pings after
            *timeout* seconds.

        """
        os.kill(self.pid, signal.SIGTERM)
        try_while(
            lambda: not self.is_alive(interval),
            "Worker won't shutdown (after %s secs.)" % timeout,
            timeout=timeout, interval=interval,
        )
        say('--WORKER %s IS SHUTDOWN--' % self.hostname)
        self._shutdown_called = True

    def ensure_started(self):
        """Start the worker and wait until it answers a ping."""
        self.start()
        self.wait_until_started()

    @classmethod
    def managed(cls, hostname=None, caller=None):
        """Start a worker and register an at-exit hook that stops it.

        :param hostname: base node name; defaults to the machine's hostname.
        :param caller: optional object whose qualified name is prefixed to
            the hostname.  Without it, the next value of the shared
            ``worker_ids`` counter is appended instead.

        :returns: the started :class:`Worker`.

        """
        hostname = hostname or socket.gethostname()
        if caller:
            hostname = '.'.join([qualname(caller), hostname])
        else:
            # ``worker_ids`` is an iterator, not a callable.
            hostname += str(next(cls.worker_ids))
        worker = cls(hostname)
        worker.ensure_started()
        # Captured now so the at-exit report can show where the worker
        # was created.
        stack = traceback.format_stack()

        @atexit.register
        def _ensure_shutdown_once():
            if not worker._shutdown_called:
                say('-- Found worker not stopped at shutdown: %s\n%s' % (
                    worker.hostname,
                    '\n'.join(stack)))
                worker.ensure_shutdown()

        return worker


class WorkerCase(Case):
    """Test case that runs a managed worker for the lifetime of the class."""

    hostname = HOSTNAME
    worker = None

    @classmethod
    def setUpClass(cls):
        """Quiet the ``amqp`` logger and start the class-wide worker."""
        logging.getLogger('amqp').setLevel(logging.ERROR)
        cls.worker = Worker.managed(cls.hostname, caller=cls)

    @classmethod
    def tearDownClass(cls):
        """Shut the class-wide worker down."""
        cls.worker.ensure_shutdown()

    def assertWorkerAlive(self, timeout=1):
        """Assert that the worker answers a ping within *timeout* seconds."""
        self.assertTrue(self.worker.is_alive(timeout=timeout))

    def inspect(self, timeout=1):
        """Return an inspect object targeting only this case's worker."""
        return self.app.control.inspect([self.worker.hostname],
                                        timeout=timeout)

    def my_response(self, response):
        """Return this worker's entry from a remote control *response*."""
        return flatten_reply(response)[self.worker.hostname]

    def is_accepted(self, task_id, interval=0.5):
        """Return True if *task_id* is in the worker's ``active`` reply."""
        active = self.inspect(timeout=interval).active()
        if not active:
            return False
        return any(task['id'] == task_id
                   for task in active[self.worker.hostname])

    def is_reserved(self, task_id, interval=0.5):
        """Return True if *task_id* is in the worker's ``reserved`` reply."""
        reserved = self.inspect(timeout=interval).reserved()
        if not reserved:
            return False
        return any(task['id'] == task_id
                   for task in reserved[self.worker.hostname])

    def is_scheduled(self, task_id, interval=0.5):
        """Return True if *task_id* is in the worker's ``scheduled`` reply."""
        schedule = self.inspect(timeout=interval).scheduled()
        if not schedule:
            return False
        return any(item['request']['id'] == task_id
                   for item in schedule[self.worker.hostname])

    def is_received(self, task_id, interval=0.5):
        """Return True if *task_id* is reserved, scheduled or accepted."""
        return (self.is_reserved(task_id, interval) or
                self.is_scheduled(task_id, interval) or
                self.is_accepted(task_id, interval))

    def ensure_accepted(self, task_id, interval=0.5, timeout=10):
        """Poll until *task_id* is accepted.

        :raises TimeoutError: if that does not happen within *timeout*.

        """
        return try_while(lambda: self.is_accepted(task_id, interval),
                         'Task not accepted within timeout',
                         interval=interval, timeout=timeout)

    def ensure_received(self, task_id, interval=0.5, timeout=10):
        """Poll until *task_id* is received.

        :raises TimeoutError: if that does not happen within *timeout*.

        """
        return try_while(lambda: self.is_received(task_id, interval),
                         'Task not received within timeout',
                         interval=interval, timeout=timeout)

    def ensure_scheduled(self, task_id, interval=0.5, timeout=10):
        """Poll until *task_id* is scheduled.

        :raises TimeoutError: if that does not happen within *timeout*.

        """
        return try_while(lambda: self.is_scheduled(task_id, interval),
                         'Task not scheduled within timeout',
                         interval=interval, timeout=timeout)