"""Unit tests for :mod:`celery.worker.state`."""
from __future__ import absolute_import

import pickle
from time import time

from mock import Mock, patch

from celery.datastructures import LimitedSet
from celery.exceptions import SystemTerminate
from celery.worker import state

from celery.tests.case import AppCase


class StateResetCase(AppCase):
    """Base case whose ``setup``/``teardown`` hooks clear the worker state."""

    def setup(self):
        self.reset_state()

    def teardown(self):
        self.reset_state()

    def reset_state(self):
        """Empty the module-level containers that the tests mutate."""
        state.active_requests.clear()
        state.revoked.clear()
        state.total_count.clear()


class MockShelve(dict):
    """Dict-backed stand-in for the storage backend used by ``Persistent``.

    Records the filename passed to :meth:`open` and whether :meth:`sync`
    and :meth:`close` have been called, so tests can assert on them.
    """

    filename = None
    in_sync = False
    closed = False

    def open(self, filename, **kwargs):
        """Remember ``filename`` and return this object as the opened db."""
        self.filename = filename
        return self

    def sync(self):
        self.in_sync = True

    def close(self):
        self.closed = True


class MyPersistent(state.Persistent):
    """``Persistent`` subclass wired to the in-memory :class:`MockShelve`."""

    # Class-level default; ``test_Persistent.setup`` replaces it with a
    # fresh instance before every test.
    storage = MockShelve()


class test_maybe_shutdown(AppCase):
    """``maybe_shutdown`` must raise according to the module-level flags."""

    def teardown(self):
        # Restore the flags so a set flag cannot leak into other tests.
        state.should_stop = False
        state.should_terminate = False

    def test_should_stop(self):
        state.should_stop = True
        with self.assertRaises(SystemExit):
            state.maybe_shutdown()

    def test_should_terminate(self):
        state.should_terminate = True
        with self.assertRaises(SystemTerminate):
            state.maybe_shutdown()


class test_Persistent(StateResetCase):
    """Tests for ``state.Persistent`` running against a mock storage."""

    def setup(self):
        self.reset_state()
        # ``MockShelve.open`` returns the storage object itself, so a single
        # shared instance would carry keys and the in_sync/closed flags from
        # one test into the next.  Install a fresh one for every test.
        MyPersistent.storage = MockShelve()
        self.p = MyPersistent(state, filename='celery-state')

    def test_close_twice(self):
        # Closing when the db is flagged as not open must not raise.
        self.p._is_open = False
        self.p.close()

    def test_constructor(self):
        self.assertDictEqual(self.p.db, {})
        self.assertEqual(self.p.db.filename, self.p.filename)

    def test_save(self):
        self.p.db['foo'] = 'bar'
        self.p.save()
        self.assertTrue(self.p.db.in_sync)
        self.assertTrue(self.p.db.closed)

    def add_revoked(self, *ids):
        """Add each of ``ids`` to the ``'revoked'`` set stored in the db."""
        for task_id in ids:
            self.p.db.setdefault('revoked', LimitedSet()).add(task_id)

    def test_merge(self, data=('foo', 'bar', 'baz')):
        self.add_revoked(*data)
        self.p.merge()
        for item in data:
            self.assertIn(item, state.revoked)

    def test_merge_dict(self):
        self.p.clock = Mock()
        self.p.clock.adjust.return_value = 626
        d = {'revoked': {'abc': time()}, 'clock': 313}
        self.p._merge_with(d)
        self.p.clock.adjust.assert_called_with(313)
        self.assertEqual(d['clock'], 626)
        self.assertIn('abc', state.revoked)

    def test_sync_clock_and_purge(self):
        # Identity stand-in for serialization/compression so the stored
        # ``zrevoked`` value can be compared by identity below.
        passthrough = Mock()
        passthrough.side_effect = lambda x: x
        with patch('celery.worker.state.revoked') as revoked:
            d = {'clock': 0}
            self.p.clock = Mock()
            self.p.clock.forward.return_value = 627
            self.p._dumps = passthrough
            self.p.compress = passthrough
            self.p._sync_with(d)
            revoked.purge.assert_called_with()
            self.assertEqual(d['clock'], 627)
            self.assertNotIn('revoked', d)
            self.assertIs(d['zrevoked'], revoked)

    def test_sync(self, data1=('foo', 'bar', 'baz'),
                  data2=('baz', 'ini', 'koz')):
        self.add_revoked(*data1)
        for item in data2:
            state.revoked.add(item)
        self.p.sync()

        self.assertTrue(self.p.db['zrevoked'])
        pickled = self.p.decompress(self.p.db['zrevoked'])
        self.assertTrue(pickled)
        # Safe here: the payload was produced by this test, not external data.
        saved = pickle.loads(pickled)
        for item in data2:
            self.assertIn(item, saved)


class SimpleReq(object):
    """Minimal request stand-in exposing only a ``name`` attribute."""

    def __init__(self, name):
        self.name = name


class test_state(StateResetCase):
    """Tests for the accepted/ready bookkeeping functions in ``state``."""

    def test_accepted(self, requests=(SimpleReq('foo'),
                                      SimpleReq('bar'),
                                      SimpleReq('baz'),
                                      SimpleReq('baz'))):
        for request in requests:
            state.task_accepted(request)
        for req in requests:
            self.assertIn(req, state.active_requests)
        self.assertEqual(state.total_count['foo'], 1)
        self.assertEqual(state.total_count['bar'], 1)
        self.assertEqual(state.total_count['baz'], 2)

    def test_ready(self, requests=(SimpleReq('foo'),
                                   SimpleReq('bar'))):
        for request in requests:
            state.task_accepted(request)
        self.assertEqual(len(state.active_requests), 2)
        for request in requests:
            state.task_ready(request)
        self.assertEqual(len(state.active_requests), 0)