123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162 |
- from __future__ import absolute_import
- import pickle
- from mock import Mock, patch
- from time import time
- from celery.datastructures import LimitedSet
- from celery.exceptions import SystemTerminate
- from celery.worker import state
- from celery.tests.case import AppCase
class StateResetCase(AppCase):
    """Base case that wipes the worker's global bookkeeping around each test.

    ``active_requests``, ``revoked`` and ``total_count`` live on the
    :mod:`celery.worker.state` module, so they are cleared both before and
    after every test to keep tests independent of one another.
    """

    def setup(self):
        self.reset_state()

    def teardown(self):
        self.reset_state()

    def reset_state(self):
        """Empty the accepted-request, revoked-id and per-name counters."""
        # Cleared in the same order as before: requests, revoked, counts.
        for registry in (state.active_requests,
                         state.revoked,
                         state.total_count):
            registry.clear()
class MockShelve(dict):
    """In-memory stand-in for the shelve-like ``storage`` of ``Persistent``.

    One object plays both roles: ``open()`` hands back the object itself, so
    it doubles as the opened database (a plain ``dict``).  The ``in_sync``
    and ``closed`` flags let tests check that ``sync()`` / ``close()`` ran.
    """

    # Set by open(); stays None until the shelf has been opened.
    filename = None
    # Flipped to True by sync() and close() respectively.
    in_sync = False
    closed = False

    def close(self):
        """Record that the shelf was closed."""
        self.closed = True

    def sync(self):
        """Record that the shelf was flushed."""
        self.in_sync = True

    def open(self, filename, **options):
        """Remember *filename*, ignore any *options*, and return ``self``."""
        self.filename = filename
        return self
class MyPersistent(state.Persistent):
    """``state.Persistent`` backed by an in-memory :class:`MockShelve`.

    Every instance gets its own fresh ``MockShelve`` so that data written to
    ``p.db`` by one test (e.g. ``'foo'`` in ``test_save`` or the
    ``'revoked'`` set in ``test_merge``) cannot leak into the next test.
    A single class-level shelf would be shared by all instances, which makes
    assertions like "db starts empty" depend on test execution order.
    """

    # Kept for backward compatibility with code that reads the class
    # attribute directly; instances shadow it in __init__.
    storage = MockShelve()

    def __init__(self, *args, **kwargs):
        # Bind the per-instance shelf *before* the base initialiser runs,
        # in case it already opens/merges from storage (not visible here).
        self.storage = MockShelve()
        super(MyPersistent, self).__init__(*args, **kwargs)
class test_maybe_shutdown(AppCase):
    """``state.maybe_shutdown()`` raises according to the module stop flags."""

    def teardown(self):
        # Lower both flags again so they do not leak into later tests.
        for flag in ('should_stop', 'should_terminate'):
            setattr(state, flag, False)

    def _assert_shutdown_raises(self, flag, exc):
        """Raise *flag* on the state module and expect *exc* from the check."""
        setattr(state, flag, True)
        with self.assertRaises(exc):
            state.maybe_shutdown()

    def test_should_stop(self):
        self._assert_shutdown_raises('should_stop', SystemExit)

    def test_should_terminate(self):
        self._assert_shutdown_raises('should_terminate', SystemTerminate)
class test_Persistent(StateResetCase):
    """Tests for ``state.Persistent`` using the in-memory ``MyPersistent``."""

    def setup(self):
        self.reset_state()
        self.p = MyPersistent(state, filename='celery-state')

    def test_close_twice(self):
        # With the "open" marker already cleared, calling close() again
        # must not raise.
        self.p._is_open = False
        self.p.close()

    def test_constructor(self):
        # The shelf starts empty and was opened with the instance's filename.
        self.assertDictEqual(self.p.db, {})
        self.assertEqual(self.p.db.filename, self.p.filename)

    def test_save(self):
        # save() is expected to both sync and close the underlying shelf.
        self.p.db['foo'] = 'bar'
        self.p.save()
        self.assertTrue(self.p.db.in_sync)
        self.assertTrue(self.p.db.closed)

    def add_revoked(self, *ids):
        """Add each task id in *ids* to the shelf's ``'revoked'`` LimitedSet.

        The set is created on first use.
        """
        for task_id in ids:
            self.p.db.setdefault('revoked', LimitedSet()).add(task_id)

    # Defaults are tuples, not lists: a mutable default would be shared by
    # every call of the method.
    def test_merge(self, data=('foo', 'bar', 'baz')):
        # Ids persisted in the shelf end up in the live state after merge().
        self.add_revoked(*data)
        self.p.merge()
        for item in data:
            self.assertIn(item, state.revoked)

    def test_merge_dict(self):
        # The stored clock value is fed to clock.adjust() and the adjusted
        # value is written back; stored revoked ids reach state.revoked.
        self.p.clock = Mock()
        self.p.clock.adjust.return_value = 626
        d = {'revoked': {'abc': time()}, 'clock': 313}
        self.p._merge_with(d)
        self.p.clock.adjust.assert_called_with(313)
        self.assertEqual(d['clock'], 626)
        self.assertIn('abc', state.revoked)

    def test_sync_clock_and_purge(self):
        # Identity stand-ins for _dumps/compress let the test check that
        # 'zrevoked' holds exactly the (patched) live revoked object.
        passthrough = Mock()
        passthrough.side_effect = lambda x: x
        with patch('celery.worker.state.revoked') as revoked:
            d = {'clock': 0}
            self.p.clock = Mock()
            self.p.clock.forward.return_value = 627
            self.p._dumps = passthrough
            self.p.compress = passthrough
            self.p._sync_with(d)
            revoked.purge.assert_called_with()
            self.assertEqual(d['clock'], 627)
            self.assertNotIn('revoked', d)
            self.assertIs(d['zrevoked'], revoked)

    def test_sync(self, data1=('foo', 'bar', 'baz'),
                  data2=('baz', 'ini', 'koz')):
        # data1 goes into the shelf, data2 into the live state; only the
        # live-state ids are asserted in the decompressed + unpickled dump.
        self.add_revoked(*data1)
        for item in data2:
            state.revoked.add(item)
        self.p.sync()

        self.assertTrue(self.p.db['zrevoked'])
        pickled = self.p.decompress(self.p.db['zrevoked'])
        self.assertTrue(pickled)
        # Safe here: the payload was produced by this test, not external.
        saved = pickle.loads(pickled)
        for item in data2:
            self.assertIn(item, saved)
class SimpleReq(object):
    """Minimal request double that carries nothing but a task ``name``."""

    def __init__(self, name):
        # The only attribute this double provides to the state helpers.
        self.name = name
class test_state(StateResetCase):
    """Request bookkeeping done by ``state.task_accepted`` / ``task_ready``."""

    def test_accepted(self, requests=None):
        # Build fresh request objects per call. A list default would be
        # created once at class-definition time and shared, objects and
        # all, by every invocation.
        if requests is None:
            requests = [SimpleReq('foo'), SimpleReq('bar'),
                        SimpleReq('baz'), SimpleReq('baz')]
        for request in requests:
            state.task_accepted(request)
        for request in requests:
            self.assertIn(request, state.active_requests)
        # Counts are keyed by request name; the two distinct 'baz'
        # requests must both be counted.
        self.assertEqual(state.total_count['foo'], 1)
        self.assertEqual(state.total_count['bar'], 1)
        self.assertEqual(state.total_count['baz'], 2)

    def test_ready(self, requests=None):
        # Accepting then readying each request leaves no active requests.
        if requests is None:
            requests = [SimpleReq('foo'), SimpleReq('bar')]
        for request in requests:
            state.task_accepted(request)
        self.assertEqual(len(state.active_requests), 2)
        for request in requests:
            state.task_ready(request)
        self.assertEqual(len(state.active_requests), 0)
|