from __future__ import absolute_import import os import socket from collections import deque from datetime import datetime, timedelta from threading import Event from amqp import ChannelError from kombu import Connection from kombu.common import QoS, ignore_errors from kombu.transport.base import Message from mock import Mock, patch from celery.app.defaults import DEFAULTS from celery.bootsteps import RUN, CLOSE, StartStopStep from celery.concurrency.base import BasePool from celery.datastructures import AttributeDict from celery.exceptions import SystemTerminate, TaskRevokedError from celery.five import Empty, range, Queue as FastQueue from celery.utils import uuid from celery.worker import components from celery.worker import consumer from celery.worker.consumer import Consumer as __Consumer from celery.worker.job import Request from celery.utils import worker_direct from celery.utils.serialization import pickle from celery.utils.timer2 import Timer from celery.tests.case import AppCase, restore_logging def MockStep(step=None): step = Mock() if step is None else step step.blueprint = Mock() step.blueprint.name = 'MockNS' step.name = 'MockStep(%s)' % (id(step), ) return step class PlaceHolder(object): pass def find_step(obj, typ): return obj.blueprint.steps[typ.name] class Consumer(__Consumer): def __init__(self, *args, **kwargs): kwargs.setdefault('without_mingle', True) # disable Mingle step kwargs.setdefault('without_gossip', True) # disable Gossip step kwargs.setdefault('without_heartbeat', True) # disable Heart step super(Consumer, self).__init__(*args, **kwargs) class _MyKombuConsumer(Consumer): broadcast_consumer = Mock() task_consumer = Mock() def __init__(self, *args, **kwargs): kwargs.setdefault('pool', BasePool(2)) super(_MyKombuConsumer, self).__init__(*args, **kwargs) def restart_heartbeat(self): self.heart = None class MyKombuConsumer(Consumer): def loop(self, *args, **kwargs): pass class MockNode(object): commands = [] def handle_message(self, body, message): self.commands.append(body.pop('command', None)) class MockEventDispatcher(object): sent = [] closed = False flushed = False _outbound_buffer = [] def send(self, event, *args, **kwargs): self.sent.append(event) def close(self): self.closed = True def flush(self): self.flushed = True class MockHeart(object): closed = False def stop(self): self.closed = True def create_message(channel, **data): data.setdefault('id', uuid()) channel.no_ack_consumers = set() m = Message(channel, body=pickle.dumps(dict(**data)), content_type='application/x-python-serialize', content_encoding='binary', delivery_info={'consumer_tag': 'mock'}) m.accept = ['application/x-python-serialize'] return m class test_Consumer(AppCase): def setup(self): self.buffer = FastQueue() self.timer = Timer() @self.app.task(shared=False) def foo_task(x, y, z): return x * y * z self.foo_task = foo_task def teardown(self): self.timer.stop() def test_info(self): l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app) l.task_consumer = Mock() l.qos = QoS(l.task_consumer.qos, 10) l.connection = Mock() l.connection.info.return_value = {'foo': 'bar'} l.controller = l.app.WorkController() l.controller.pool = Mock() l.controller.pool.info.return_value = [Mock(), Mock()] l.controller.consumer = l info = l.controller.stats() self.assertEqual(info['prefetch_count'], 10) self.assertTrue(info['broker']) def test_start_when_closed(self): l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app) l.blueprint.state = CLOSE l.start() def test_connection(self): l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app) l.blueprint.start(l) self.assertIsInstance(l.connection, Connection) l.blueprint.state = RUN l.event_dispatcher = None l.blueprint.restart(l) self.assertTrue(l.connection) l.blueprint.state = RUN l.shutdown() self.assertIsNone(l.connection) self.assertIsNone(l.task_consumer) l.blueprint.start(l) self.assertIsInstance(l.connection, Connection) l.blueprint.restart(l) l.stop() l.shutdown() self.assertIsNone(l.connection) self.assertIsNone(l.task_consumer) def test_close_connection(self): l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app) l.blueprint.state = RUN step = find_step(l, consumer.Connection) conn = l.connection = Mock() step.shutdown(l) self.assertTrue(conn.close.called) self.assertIsNone(l.connection) l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app) eventer = l.event_dispatcher = Mock() eventer.enabled = True heart = l.heart = MockHeart() l.blueprint.state = RUN Events = find_step(l, consumer.Events) Events.shutdown(l) Heart = find_step(l, consumer.Heart) Heart.shutdown(l) self.assertTrue(eventer.close.call_count) self.assertTrue(heart.closed) @patch('celery.worker.consumer.warn') def test_receive_message_unknown(self, warn): l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app) l.blueprint.state = RUN l.steps.pop() backend = Mock() m = create_message(backend, unknown={'baz': '!!!'}) l.event_dispatcher = Mock() l.node = MockNode() callback = self._get_on_message(l) callback(m.decode(), m) self.assertTrue(warn.call_count) @patch('celery.worker.strategy.to_timestamp') def test_receive_message_eta_OverflowError(self, to_timestamp): to_timestamp.side_effect = OverflowError() l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app) l.blueprint.state = RUN l.steps.pop() m = create_message(Mock(), task=self.foo_task.name, args=('2, 2'), kwargs={}, eta=datetime.now().isoformat()) l.event_dispatcher = Mock() l.node = MockNode() l.update_strategies() l.qos = Mock() callback = self._get_on_message(l) callback(m.decode(), m) self.assertTrue(m.acknowledged) @patch('celery.worker.consumer.error') def test_receive_message_InvalidTaskError(self, error): l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app) l.blueprint.state = RUN l.event_dispatcher = Mock() l.steps.pop() m = create_message(Mock(), task=self.foo_task.name, args=(1, 2), kwargs='foobarbaz', id=1) l.update_strategies() l.event_dispatcher = Mock() callback = self._get_on_message(l) callback(m.decode(), m) self.assertIn('Received invalid task message', error.call_args[0][0]) @patch('celery.worker.consumer.crit') def test_on_decode_error(self, crit): l = Consumer(self.buffer.put, timer=self.timer, app=self.app) class MockMessage(Mock): content_type = 'application/x-msgpack' content_encoding = 'binary' body = 'foobarbaz' message = MockMessage() l.on_decode_error(message, KeyError('foo')) self.assertTrue(message.ack.call_count) self.assertIn("Can't decode message body", crit.call_args[0][0]) def _get_on_message(self, l): if l.qos is None: l.qos = Mock() l.event_dispatcher = Mock() l.task_consumer = Mock() l.connection = Mock() l.connection.drain_events.side_effect = SystemExit() with self.assertRaises(SystemExit): l.loop(*l.loop_args()) self.assertTrue(l.task_consumer.register_callback.called) return l.task_consumer.register_callback.call_args[0][0] def test_receieve_message(self): l = Consumer(self.buffer.put, timer=self.timer, app=self.app) l.blueprint.state = RUN l.event_dispatcher = Mock() m = create_message(Mock(), task=self.foo_task.name, args=[2, 4, 8], kwargs={}) l.update_strategies() callback = self._get_on_message(l) callback(m.decode(), m) in_bucket = self.buffer.get_nowait() self.assertIsInstance(in_bucket, Request) self.assertEqual(in_bucket.name, self.foo_task.name) self.assertEqual(in_bucket.execute(), 2 * 4 * 8) self.assertTrue(self.timer.empty()) def test_start_channel_error(self): class MockConsumer(Consumer): iterations = 0 def loop(self, *args, **kwargs): if not self.iterations: self.iterations = 1 raise KeyError('foo') raise SyntaxError('bar') l = MockConsumer(self.buffer.put, timer=self.timer, send_events=False, pool=BasePool(), app=self.app) l.channel_errors = (KeyError, ) with self.assertRaises(KeyError): l.start() l.timer.stop() def test_start_connection_error(self): class MockConsumer(Consumer): iterations = 0 def loop(self, *args, **kwargs): if not self.iterations: self.iterations = 1 raise KeyError('foo') raise SyntaxError('bar') l = MockConsumer(self.buffer.put, timer=self.timer, send_events=False, pool=BasePool(), app=self.app) l.connection_errors = (KeyError, ) self.assertRaises(SyntaxError, l.start) l.timer.stop() def test_loop_ignores_socket_timeout(self): class Connection(self.app.connection().__class__): obj = None def drain_events(self, **kwargs): self.obj.connection = None raise socket.timeout(10) l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app) l.connection = Connection() l.task_consumer = Mock() l.connection.obj = l l.qos = QoS(l.task_consumer.qos, 10) l.loop(*l.loop_args()) def test_loop_when_socket_error(self): class Connection(self.app.connection().__class__): obj = None def drain_events(self, **kwargs): self.obj.connection = None raise socket.error('foo') l = Consumer(self.buffer.put, timer=self.timer, app=self.app) l.blueprint.state = RUN c = l.connection = Connection() l.connection.obj = l l.task_consumer = Mock() l.qos = QoS(l.task_consumer.qos, 10) with self.assertRaises(socket.error): l.loop(*l.loop_args()) l.blueprint.state = CLOSE l.connection = c l.loop(*l.loop_args()) def test_loop(self): class Connection(self.app.connection().__class__): obj = None def drain_events(self, **kwargs): self.obj.connection = None l = Consumer(self.buffer.put, timer=self.timer, app=self.app) l.blueprint.state = RUN l.connection = Connection() l.connection.obj = l l.task_consumer = Mock() l.qos = QoS(l.task_consumer.qos, 10) l.loop(*l.loop_args()) l.loop(*l.loop_args()) self.assertTrue(l.task_consumer.consume.call_count) l.task_consumer.qos.assert_called_with(prefetch_count=10) self.assertEqual(l.qos.value, 10) l.qos.decrement_eventually() self.assertEqual(l.qos.value, 9) l.qos.update() self.assertEqual(l.qos.value, 9) l.task_consumer.qos.assert_called_with(prefetch_count=9) def test_ignore_errors(self): l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app) l.connection_errors = (AttributeError, KeyError, ) l.channel_errors = (SyntaxError, ) ignore_errors(l, Mock(side_effect=AttributeError('foo'))) ignore_errors(l, Mock(side_effect=KeyError('foo'))) ignore_errors(l, Mock(side_effect=SyntaxError('foo'))) with self.assertRaises(IndexError): ignore_errors(l, Mock(side_effect=IndexError('foo'))) def test_apply_eta_task(self): from celery.worker import state l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app) l.qos = QoS(None, 10) task = object() qos = l.qos.value l.apply_eta_task(task) self.assertIn(task, state.reserved_requests) self.assertEqual(l.qos.value, qos - 1) self.assertIs(self.buffer.get_nowait(), task) def test_receieve_message_eta_isoformat(self): l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app) l.blueprint.state = RUN l.steps.pop() m = create_message( Mock(), task=self.foo_task.name, eta=(datetime.now() + timedelta(days=1)).isoformat(), args=[2, 4, 8], kwargs={}, ) l.task_consumer = Mock() l.qos = QoS(l.task_consumer.qos, 1) current_pcount = l.qos.value l.event_dispatcher = Mock() l.enabled = False l.update_strategies() callback = self._get_on_message(l) callback(m.decode(), m) l.timer.stop() l.timer.join(1) items = [entry[2] for entry in self.timer.queue] found = 0 for item in items: if item.args[0].name == self.foo_task.name: found = True self.assertTrue(found) self.assertGreater(l.qos.value, current_pcount) l.timer.stop() def test_pidbox_callback(self): l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app) con = find_step(l, consumer.Control).box con.node = Mock() con.reset = Mock() con.on_message('foo', 'bar') con.node.handle_message.assert_called_with('foo', 'bar') con.node = Mock() con.node.handle_message.side_effect = KeyError('foo') con.on_message('foo', 'bar') con.node.handle_message.assert_called_with('foo', 'bar') con.node = Mock() con.node.handle_message.side_effect = ValueError('foo') con.on_message('foo', 'bar') con.node.handle_message.assert_called_with('foo', 'bar') self.assertTrue(con.reset.called) def test_revoke(self): l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app) l.blueprint.state = RUN l.steps.pop() backend = Mock() id = uuid() t = create_message(backend, task=self.foo_task.name, args=[2, 4, 8], kwargs={}, id=id) from celery.worker.state import revoked revoked.add(id) callback = self._get_on_message(l) callback(t.decode(), t) self.assertTrue(self.buffer.empty()) def test_receieve_message_not_registered(self): l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app) l.blueprint.state = RUN l.steps.pop() backend = Mock() m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={}) l.event_dispatcher = Mock() callback = self._get_on_message(l) self.assertFalse(callback(m.decode(), m)) with self.assertRaises(Empty): self.buffer.get_nowait() self.assertTrue(self.timer.empty()) @patch('celery.worker.consumer.warn') @patch('celery.worker.consumer.logger') def test_receieve_message_ack_raises(self, logger, warn): l = Consumer(self.buffer.put, timer=self.timer, app=self.app) l.blueprint.state = RUN backend = Mock() m = create_message(backend, args=[2, 4, 8], kwargs={}) l.event_dispatcher = Mock() l.connection_errors = (socket.error, ) m.reject = Mock() m.reject.side_effect = socket.error('foo') callback = self._get_on_message(l) self.assertFalse(callback(m.decode(), m)) self.assertTrue(warn.call_count) with self.assertRaises(Empty): self.buffer.get_nowait() self.assertTrue(self.timer.empty()) m.reject.assert_called_with(requeue=False) self.assertTrue(logger.critical.call_count) def test_receive_message_eta(self): l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app) l.steps.pop() l.event_dispatcher = Mock() l.event_dispatcher._outbound_buffer = deque() backend = Mock() m = create_message( backend, task=self.foo_task.name, args=[2, 4, 8], kwargs={}, eta=(datetime.now() + timedelta(days=1)).isoformat(), ) try: l.blueprint.start(l) p = l.app.conf.BROKER_CONNECTION_RETRY l.app.conf.BROKER_CONNECTION_RETRY = False l.blueprint.start(l) l.app.conf.BROKER_CONNECTION_RETRY = p l.blueprint.restart(l) l.event_dispatcher = Mock() callback = self._get_on_message(l) callback(m.decode(), m) finally: l.timer.stop() l.timer.join() in_hold = l.timer.queue[0] self.assertEqual(len(in_hold), 3) eta, priority, entry = in_hold task = entry.args[0] self.assertIsInstance(task, Request) self.assertEqual(task.name, self.foo_task.name) self.assertEqual(task.execute(), 2 * 4 * 8) with self.assertRaises(Empty): self.buffer.get_nowait() def test_reset_pidbox_node(self): l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app) con = find_step(l, consumer.Control).box con.node = Mock() chan = con.node.channel = Mock() l.connection = Mock() chan.close.side_effect = socket.error('foo') l.connection_errors = (socket.error, ) con.reset() chan.close.assert_called_with() def test_reset_pidbox_node_green(self): from celery.worker.pidbox import gPidbox pool = Mock() pool.is_green = True l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool, app=self.app) con = find_step(l, consumer.Control) self.assertIsInstance(con.box, gPidbox) con.start(l) l.pool.spawn_n.assert_called_with( con.box.loop, l, ) def test__green_pidbox_node(self): pool = Mock() pool.is_green = True l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool, app=self.app) l.node = Mock() controller = find_step(l, consumer.Control) class BConsumer(Mock): def __enter__(self): self.consume() return self def __exit__(self, *exc_info): self.cancel() controller.box.node.listen = BConsumer() connections = [] class Connection(object): calls = 0 def __init__(self, obj): connections.append(self) self.obj = obj self.default_channel = self.channel() self.closed = False def __enter__(self): return self def __exit__(self, *exc_info): self.close() def channel(self): return Mock() def as_uri(self): return 'dummy://' def drain_events(self, **kwargs): if not self.calls: self.calls += 1 raise socket.timeout() self.obj.connection = None controller.box._node_shutdown.set() def close(self): self.closed = True l.connection = Mock() l.connect = lambda: Connection(obj=l) controller = find_step(l, consumer.Control) controller.box.loop(l) self.assertTrue(controller.box.node.listen.called) self.assertTrue(controller.box.consumer) controller.box.consumer.consume.assert_called_with() self.assertIsNone(l.connection) self.assertTrue(connections[0].closed) @patch('kombu.connection.Connection._establish_connection') @patch('kombu.utils.sleep') def test_connect_errback(self, sleep, connect): l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app) from kombu.transport.memory import Transport Transport.connection_errors = (ChannelError, ) def effect(): if connect.call_count > 1: return raise ChannelError('error') connect.side_effect = effect l.connect() connect.assert_called_with() def test_stop_pidbox_node(self): l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app) cont = find_step(l, consumer.Control) cont._node_stopped = Event() cont._node_shutdown = Event() cont._node_stopped.set() cont.stop(l) def test_start__loop(self): class _QoS(object): prev = 3 value = 4 def update(self): self.prev = self.value class _Consumer(MyKombuConsumer): iterations = 0 def reset_connection(self): if self.iterations >= 1: raise KeyError('foo') init_callback = Mock() l = _Consumer(self.buffer.put, timer=self.timer, init_callback=init_callback, app=self.app) l.task_consumer = Mock() l.broadcast_consumer = Mock() l.qos = _QoS() l.connection = Connection() l.iterations = 0 def raises_KeyError(*args, **kwargs): l.iterations += 1 if l.qos.prev != l.qos.value: l.qos.update() if l.iterations >= 2: raise KeyError('foo') l.loop = raises_KeyError with self.assertRaises(KeyError): l.start() self.assertEqual(l.iterations, 2) self.assertEqual(l.qos.prev, l.qos.value) init_callback.reset_mock() l = _Consumer(self.buffer.put, timer=self.timer, app=self.app, send_events=False, init_callback=init_callback) l.qos = _QoS() l.task_consumer = Mock() l.broadcast_consumer = Mock() l.connection = Connection() l.loop = Mock(side_effect=socket.error('foo')) with self.assertRaises(socket.error): l.start() self.assertTrue(l.loop.call_count) def test_reset_connection_with_no_node(self): l = Consumer(self.buffer.put, timer=self.timer, app=self.app) l.steps.pop() self.assertEqual(None, l.pool) l.blueprint.start(l) class test_WorkController(AppCase): def setup(self): self.worker = self.create_worker() from celery import worker self._logger = worker.logger self._comp_logger = components.logger self.logger = worker.logger = Mock() self.comp_logger = components.logger = Mock() @self.app.task(shared=False) def foo_task(x, y, z): return x * y * z self.foo_task = foo_task def teardown(self): from celery import worker worker.logger = self._logger components.logger = self._comp_logger def create_worker(self, **kw): worker = self.app.WorkController(concurrency=1, loglevel=0, **kw) worker.blueprint.shutdown_complete.set() return worker def test_on_consumer_ready(self): self.worker.on_consumer_ready(Mock()) def test_setup_queues_worker_direct(self): self.app.conf.CELERY_WORKER_DIRECT = True self.app.amqp.__dict__['queues'] = Mock() self.worker.setup_queues({}) self.app.amqp.queues.select_add.assert_called_with( worker_direct(self.worker.hostname), ) def test_send_worker_shutdown(self): with patch('celery.signals.worker_shutdown') as ws: self.worker._send_worker_shutdown() ws.send.assert_called_with(sender=self.worker) def test_process_shutdown_on_worker_shutdown(self): from celery.concurrency.prefork import process_destructor from celery.concurrency.asynpool import Worker with patch('celery.signals.worker_process_shutdown') as ws: Worker._make_shortcuts = Mock() with patch('os._exit') as _exit: worker = Worker(None, None, on_exit=process_destructor) worker._do_exit(22, 3.1415926) ws.send.assert_called_with( sender=None, pid=22, exitcode=3.1415926, ) _exit.assert_called_with(3.1415926) def test_process_task_revoked_release_semaphore(self): self.worker._quick_release = Mock() req = Mock() req.execute_using_pool.side_effect = TaskRevokedError self.worker._process_task(req) self.worker._quick_release.assert_called_with() delattr(self.worker, '_quick_release') self.worker._process_task(req) def test_shutdown_no_blueprint(self): self.worker.blueprint = None self.worker._shutdown() @patch('celery.platforms.create_pidlock') def test_use_pidfile(self, create_pidlock): create_pidlock.return_value = Mock() worker = self.create_worker(pidfile='pidfilelockfilepid') worker.steps = [] worker.start() self.assertTrue(create_pidlock.called) worker.stop() self.assertTrue(worker.pidlock.release.called) @patch('celery.platforms.signals') @patch('celery.platforms.set_mp_process_title') def test_process_initializer(self, set_mp_process_title, _signals): with restore_logging(): from celery import signals from celery._state import _tls from celery.concurrency.prefork import ( process_initializer, WORKER_SIGRESET, WORKER_SIGIGNORE, ) def on_worker_process_init(**kwargs): on_worker_process_init.called = True on_worker_process_init.called = False signals.worker_process_init.connect(on_worker_process_init) def Loader(*args, **kwargs): loader = Mock(*args, **kwargs) loader.conf = {} loader.override_backends = {} return loader with self.Celery(loader=Loader) as app: app.conf = AttributeDict(DEFAULTS) process_initializer(app, 'awesome.worker.com') _signals.ignore.assert_any_call(*WORKER_SIGIGNORE) _signals.reset.assert_any_call(*WORKER_SIGRESET) self.assertTrue(app.loader.init_worker.call_count) self.assertTrue(on_worker_process_init.called) self.assertIs(_tls.current_app, app) set_mp_process_title.assert_called_with( 'celeryd', hostname='awesome.worker.com', ) with patch('celery.app.trace.setup_worker_optimizations') as S: os.environ['FORKED_BY_MULTIPROCESSING'] = "1" try: process_initializer(app, 'luke.worker.com') S.assert_called_with(app) finally: os.environ.pop('FORKED_BY_MULTIPROCESSING', None) def test_attrs(self): worker = self.worker self.assertIsNotNone(worker.timer) self.assertIsInstance(worker.timer, Timer) self.assertIsNotNone(worker.pool) self.assertIsNotNone(worker.consumer) self.assertTrue(worker.steps) def test_with_embedded_beat(self): worker = self.app.WorkController(concurrency=1, loglevel=0, beat=True) self.assertTrue(worker.beat) self.assertIn(worker.beat, [w.obj for w in worker.steps]) def test_with_autoscaler(self): worker = self.create_worker( autoscale=[10, 3], send_events=False, timer_cls='celery.utils.timer2.Timer', ) self.assertTrue(worker.autoscaler) def test_dont_stop_or_terminate(self): worker = self.app.WorkController(concurrency=1, loglevel=0) worker.stop() self.assertNotEqual(worker.blueprint.state, CLOSE) worker.terminate() self.assertNotEqual(worker.blueprint.state, CLOSE) sigsafe, worker.pool.signal_safe = worker.pool.signal_safe, False try: worker.blueprint.state = RUN worker.stop(in_sighandler=True) self.assertNotEqual(worker.blueprint.state, CLOSE) worker.terminate(in_sighandler=True) self.assertNotEqual(worker.blueprint.state, CLOSE) finally: worker.pool.signal_safe = sigsafe def test_on_timer_error(self): worker = self.app.WorkController(concurrency=1, loglevel=0) try: raise KeyError('foo') except KeyError as exc: components.Timer(worker).on_timer_error(exc) msg, args = self.comp_logger.error.call_args[0] self.assertIn('KeyError', msg % args) def test_on_timer_tick(self): worker = self.app.WorkController(concurrency=1, loglevel=10) components.Timer(worker).on_timer_tick(30.0) xargs = self.comp_logger.debug.call_args[0] fmt, arg = xargs[0], xargs[1] self.assertEqual(30.0, arg) self.assertIn('Next eta %s secs', fmt) def test_process_task(self): worker = self.worker worker.pool = Mock() backend = Mock() m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10], kwargs={}) task = Request(m.decode(), message=m, app=self.app) worker._process_task(task) self.assertEqual(worker.pool.apply_async.call_count, 1) worker.pool.stop() def test_process_task_raise_base(self): worker = self.worker worker.pool = Mock() worker.pool.apply_async.side_effect = KeyboardInterrupt('Ctrl+C') backend = Mock() m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10], kwargs={}) task = Request(m.decode(), message=m, app=self.app) worker.steps = [] worker.blueprint.state = RUN with self.assertRaises(KeyboardInterrupt): worker._process_task(task) def test_process_task_raise_SystemTerminate(self): worker = self.worker worker.pool = Mock() worker.pool.apply_async.side_effect = SystemTerminate() backend = Mock() m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10], kwargs={}) task = Request(m.decode(), message=m, app=self.app) worker.steps = [] worker.blueprint.state = RUN with self.assertRaises(SystemExit): worker._process_task(task) def test_process_task_raise_regular(self): worker = self.worker worker.pool = Mock() worker.pool.apply_async.side_effect = KeyError('some exception') backend = Mock() m = create_message(backend, task=self.foo_task.name, args=[4, 8, 10], kwargs={}) task = Request(m.decode(), message=m, app=self.app) worker._process_task(task) worker.pool.stop() def test_start_catches_base_exceptions(self): worker1 = self.create_worker() worker1.blueprint.state = RUN stc = MockStep() stc.start.side_effect = SystemTerminate() worker1.steps = [stc] worker1.start() stc.start.assert_called_with(worker1) self.assertTrue(stc.terminate.call_count) worker2 = self.create_worker() worker2.blueprint.state = RUN sec = MockStep() sec.start.side_effect = SystemExit() sec.terminate = None worker2.steps = [sec] worker2.start() self.assertTrue(sec.stop.call_count) def test_state_db(self): from celery.worker import state Persistent = state.Persistent state.Persistent = Mock() try: worker = self.create_worker(state_db='statefilename') self.assertTrue(worker._persistence) finally: state.Persistent = Persistent def test_process_task_sem(self): worker = self.worker worker._quick_acquire = Mock() req = Mock() worker._process_task_sem(req) worker._quick_acquire.assert_called_with(worker._process_task, req) def test_signal_consumer_close(self): worker = self.worker worker.consumer = Mock() worker.signal_consumer_close() worker.consumer.close.assert_called_with() worker.consumer.close.side_effect = AttributeError() worker.signal_consumer_close() def test_start__stop(self): worker = self.worker worker.blueprint.shutdown_complete.set() worker.steps = [MockStep(StartStopStep(self)) for _ in range(4)] worker.blueprint.state = RUN worker.blueprint.started = 4 for w in worker.steps: w.start = Mock() w.close = Mock() w.stop = Mock() worker.start() for w in worker.steps: self.assertTrue(w.start.call_count) worker.consumer = Mock() worker.stop() for stopstep in worker.steps: self.assertTrue(stopstep.close.call_count) self.assertTrue(stopstep.stop.call_count) # Doesn't close pool if no pool. worker.start() worker.pool = None worker.stop() # test that stop of None is not attempted worker.steps[-1] = None worker.start() worker.stop() def test_step_raises(self): worker = self.worker step = Mock() worker.steps = [step] step.start.side_effect = TypeError() worker.stop = Mock() worker.start() worker.stop.assert_called_with() def test_state(self): self.assertTrue(self.worker.state) def test_start__terminate(self): worker = self.worker worker.blueprint.shutdown_complete.set() worker.blueprint.started = 5 worker.blueprint.state = RUN worker.steps = [MockStep() for _ in range(5)] worker.start() for w in worker.steps[:3]: self.assertTrue(w.start.call_count) self.assertTrue(worker.blueprint.started, len(worker.steps)) self.assertEqual(worker.blueprint.state, RUN) worker.terminate() for step in worker.steps: self.assertTrue(step.terminate.call_count) def test_Queues_pool_no_sem(self): w = Mock() w.pool_cls.uses_semaphore = False components.Queues(w).create(w) self.assertIs(w.process_task, w._process_task) def test_Hub_crate(self): w = Mock() x = components.Hub(w) x.create(w) self.assertTrue(w.timer.max_interval) def test_Pool_crate_threaded(self): w = Mock() w._conninfo.connection_errors = w._conninfo.channel_errors = () w.pool_cls = Mock() w.use_eventloop = False pool = components.Pool(w) pool.create(w) def test_Pool_create(self): from kombu.async.semaphore import LaxBoundedSemaphore w = Mock() w._conninfo.connection_errors = w._conninfo.channel_errors = () w.hub = Mock() PoolImp = Mock() poolimp = PoolImp.return_value = Mock() poolimp._pool = [Mock(), Mock()] poolimp._cache = {} poolimp._fileno_to_inq = {} poolimp._fileno_to_outq = {} from celery.concurrency.prefork import TaskPool as _TaskPool class MockTaskPool(_TaskPool): Pool = PoolImp @property def timers(self): return {Mock(): 30} w.pool_cls = MockTaskPool w.use_eventloop = True w.consumer.restart_count = -1 pool = components.Pool(w) pool.create(w) pool.register_with_event_loop(w, w.hub) self.assertIsInstance(w.semaphore, LaxBoundedSemaphore) P = w.pool P.start()