from __future__ import absolute_import from celery.worker.heartbeat import Heart from celery.tests.case import AppCase class MockDispatcher(object): heart = None next_iter = 0 def __init__(self): self.sent = [] self.on_enabled = set() self.on_disabled = set() self.enabled = True def send(self, msg, **_fields): self.sent.append(msg) if self.heart: if self.next_iter > 10: self.heart._shutdown.set() self.next_iter += 1 class MockDispatcherRaising(object): def send(self, msg): if msg == 'worker-offline': raise Exception('foo') class MockTimer(object): def call_repeatedly(self, secs, fun, args=(), kwargs={}): class entry(tuple): cancelled = False def cancel(self): self.cancelled = True return entry((secs, fun, args, kwargs)) def cancel(self, entry): entry.cancel() class test_Heart(AppCase): def test_start_stop(self): timer = MockTimer() eventer = MockDispatcher() h = Heart(timer, eventer, interval=1) h.start() self.assertTrue(h.tref) h.stop() self.assertIsNone(h.tref) h.stop() def test_start_when_disabled(self): timer = MockTimer() eventer = MockDispatcher() eventer.enabled = False h = Heart(timer, eventer) h.start() self.assertFalse(h.tref) def test_stop_when_disabled(self): timer = MockTimer() eventer = MockDispatcher() eventer.enabled = False h = Heart(timer, eventer) h.stop()