from __future__ import absolute_import import socket from mock import Mock from celery.events import Event from celery.tests.case import AppCase class MockProducer(object): raise_on_publish = False def __init__(self, *args, **kwargs): self.sent = [] def publish(self, msg, *args, **kwargs): if self.raise_on_publish: raise KeyError() self.sent.append(msg) def close(self): pass def has_event(self, kind): for event in self.sent: if event['type'] == kind: return event return False class test_Event(AppCase): def test_constructor(self): event = Event('world war II') self.assertEqual(event['type'], 'world war II') self.assertTrue(event['timestamp']) class test_EventDispatcher(AppCase): def test_redis_uses_fanout_exchange(self): self.app.connection = Mock() conn = self.app.connection.return_value = Mock() conn.transport.driver_type = 'redis' dispatcher = self.app.events.Dispatcher(conn, enabled=False) self.assertEqual(dispatcher.exchange.type, 'fanout') def test_others_use_topic_exchange(self): self.app.connection = Mock() conn = self.app.connection.return_value = Mock() conn.transport.driver_type = 'amqp' dispatcher = self.app.events.Dispatcher(conn, enabled=False) self.assertEqual(dispatcher.exchange.type, 'topic') def test_takes_channel_connection(self): x = self.app.events.Dispatcher(channel=Mock()) self.assertIs(x.connection, x.channel.connection.client) def test_sql_transports_disabled(self): conn = Mock() conn.transport.driver_type = 'sql' x = self.app.events.Dispatcher(connection=conn) self.assertFalse(x.enabled) def test_send(self): producer = MockProducer() producer.connection = self.app.connection() connection = Mock() connection.transport.driver_type = 'amqp' eventer = self.app.events.Dispatcher(connection, enabled=False, buffer_while_offline=False) eventer.producer = producer eventer.enabled = True eventer.send('World War II', ended=True) self.assertTrue(producer.has_event('World War II')) eventer.enabled = False eventer.send('World War III') self.assertFalse(producer.has_event('World War III')) evs = ('Event 1', 'Event 2', 'Event 3') eventer.enabled = True eventer.producer.raise_on_publish = True eventer.buffer_while_offline = False with self.assertRaises(KeyError): eventer.send('Event X') eventer.buffer_while_offline = True for ev in evs: eventer.send(ev) eventer.producer.raise_on_publish = False eventer.flush() for ev in evs: self.assertTrue(producer.has_event(ev)) buf = eventer._outbound_buffer = Mock() buf.popleft.side_effect = IndexError() eventer.flush() def test_enter_exit(self): with self.app.connection() as conn: d = self.app.events.Dispatcher(conn) d.close = Mock() with d as _d: self.assertTrue(_d) d.close.assert_called_with() def test_enable_disable_callbacks(self): on_enable = Mock() on_disable = Mock() with self.app.connection() as conn: with self.app.events.Dispatcher(conn, enabled=False) as d: d.on_enabled.add(on_enable) d.on_disabled.add(on_disable) d.enable() on_enable.assert_called_with() d.disable() on_disable.assert_called_with() def test_enabled_disable(self): connection = self.app.connection() channel = connection.channel() try: dispatcher = self.app.events.Dispatcher(connection, enabled=True) dispatcher2 = self.app.events.Dispatcher(connection, enabled=True, channel=channel) self.assertTrue(dispatcher.enabled) self.assertTrue(dispatcher.producer.channel) self.assertEqual(dispatcher.producer.serializer, self.app.conf.CELERY_EVENT_SERIALIZER) created_channel = dispatcher.producer.channel dispatcher.disable() dispatcher.disable() # Disable with no active producer dispatcher2.disable() self.assertFalse(dispatcher.enabled) self.assertIsNone(dispatcher.producer) self.assertFalse(dispatcher2.channel.closed, 'does not close manually provided channel') dispatcher.enable() self.assertTrue(dispatcher.enabled) self.assertTrue(dispatcher.producer) # XXX test compat attribute self.assertIs(dispatcher.publisher, dispatcher.producer) prev, dispatcher.publisher = dispatcher.producer, 42 try: self.assertEqual(dispatcher.producer, 42) finally: dispatcher.producer = prev finally: channel.close() connection.close() self.assertTrue(created_channel.closed) class test_EventReceiver(AppCase): def test_process(self): message = {'type': 'world-war'} got_event = [False] def my_handler(event): got_event[0] = True connection = Mock() connection.transport_cls = 'memory' r = self.app.events.Receiver( connection, handlers={'world-war': my_handler}, node_id='celery.tests', ) r._receive(message, object()) self.assertTrue(got_event[0]) def test_catch_all_event(self): message = {'type': 'world-war'} got_event = [False] def my_handler(event): got_event[0] = True connection = Mock() connection.transport_cls = 'memory' r = self.app.events.Receiver(connection, node_id='celery.tests') r.handlers['*'] = my_handler r._receive(message, object()) self.assertTrue(got_event[0]) def test_itercapture(self): connection = self.app.connection() try: r = self.app.events.Receiver(connection, node_id='celery.tests') it = r.itercapture(timeout=0.0001, wakeup=False) with self.assertRaises(socket.timeout): next(it) with self.assertRaises(socket.timeout): r.capture(timeout=0.00001) finally: connection.close() def test_event_from_message_localize_disabled(self): r = self.app.events.Receiver(Mock(), node_id='celery.tests') r.adjust_clock = Mock() ts_adjust = Mock() r.event_from_message( {'type': 'worker-online', 'clock': 313}, localize=False, adjust_timestamp=ts_adjust, ) self.assertFalse(ts_adjust.called) r.adjust_clock.assert_called_with(313) def test_itercapture_limit(self): connection = self.app.connection() channel = connection.channel() try: events_received = [0] def handler(event): events_received[0] += 1 producer = self.app.events.Dispatcher( connection, enabled=True, channel=channel, ) r = self.app.events.Receiver( connection, handlers={'*': handler}, node_id='celery.tests', ) evs = ['ev1', 'ev2', 'ev3', 'ev4', 'ev5'] for ev in evs: producer.send(ev) it = r.itercapture(limit=4, wakeup=True) next(it) # skip consumer (see itercapture) list(it) self.assertEqual(events_received[0], 4) finally: channel.close() connection.close() class test_misc(AppCase): def test_State(self): state = self.app.events.State() self.assertDictEqual(dict(state.workers), {}) def test_default_dispatcher(self): with self.app.events.default_dispatcher() as d: self.assertTrue(d) self.assertTrue(d.connection)