test_state.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. from __future__ import absolute_import
  2. import pickle
  3. from mock import Mock, patch
  4. from time import time
  5. from celery.datastructures import LimitedSet
  6. from celery.exceptions import SystemTerminate
  7. from celery.worker import state
  8. from celery.tests.case import AppCase
  9. class StateResetCase(AppCase):
  10. def setup(self):
  11. self.reset_state()
  12. def teardown(self):
  13. self.reset_state()
  14. def reset_state(self):
  15. state.active_requests.clear()
  16. state.revoked.clear()
  17. state.total_count.clear()
  18. class MockShelve(dict):
  19. filename = None
  20. in_sync = False
  21. closed = False
  22. def open(self, filename, **kwargs):
  23. self.filename = filename
  24. return self
  25. def sync(self):
  26. self.in_sync = True
  27. def close(self):
  28. self.closed = True
  29. class MyPersistent(state.Persistent):
  30. storage = MockShelve()
  31. class test_maybe_shutdown(AppCase):
  32. def teardown(self):
  33. state.should_stop = False
  34. state.should_terminate = False
  35. def test_should_stop(self):
  36. state.should_stop = True
  37. with self.assertRaises(SystemExit):
  38. state.maybe_shutdown()
  39. def test_should_terminate(self):
  40. state.should_terminate = True
  41. with self.assertRaises(SystemTerminate):
  42. state.maybe_shutdown()
  43. class test_Persistent(StateResetCase):
  44. def setup(self):
  45. self.reset_state()
  46. self.p = MyPersistent(state, filename='celery-state')
  47. def test_close_twice(self):
  48. self.p._is_open = False
  49. self.p.close()
  50. def test_constructor(self):
  51. self.assertDictEqual(self.p.db, {})
  52. self.assertEqual(self.p.db.filename, self.p.filename)
  53. def test_save(self):
  54. self.p.db['foo'] = 'bar'
  55. self.p.save()
  56. self.assertTrue(self.p.db.in_sync)
  57. self.assertTrue(self.p.db.closed)
  58. def add_revoked(self, *ids):
  59. for id in ids:
  60. self.p.db.setdefault('revoked', LimitedSet()).add(id)
  61. def test_merge(self, data=['foo', 'bar', 'baz']):
  62. self.add_revoked(*data)
  63. self.p.merge()
  64. for item in data:
  65. self.assertIn(item, state.revoked)
  66. def test_merge_dict(self):
  67. self.p.clock = Mock()
  68. self.p.clock.adjust.return_value = 626
  69. d = {'revoked': {'abc': time()}, 'clock': 313}
  70. self.p._merge_with(d)
  71. self.p.clock.adjust.assert_called_with(313)
  72. self.assertEqual(d['clock'], 626)
  73. self.assertIn('abc', state.revoked)
  74. def test_sync_clock_and_purge(self):
  75. passthrough = Mock()
  76. passthrough.side_effect = lambda x: x
  77. with patch('celery.worker.state.revoked') as revoked:
  78. d = {'clock': 0}
  79. self.p.clock = Mock()
  80. self.p.clock.forward.return_value = 627
  81. self.p._dumps = passthrough
  82. self.p.compress = passthrough
  83. self.p._sync_with(d)
  84. revoked.purge.assert_called_with()
  85. self.assertEqual(d['clock'], 627)
  86. self.assertNotIn('revoked', d)
  87. self.assertIs(d['zrevoked'], revoked)
  88. def test_sync(self, data1=['foo', 'bar', 'baz'],
  89. data2=['baz', 'ini', 'koz']):
  90. self.add_revoked(*data1)
  91. for item in data2:
  92. state.revoked.add(item)
  93. self.p.sync()
  94. self.assertTrue(self.p.db['zrevoked'])
  95. pickled = self.p.decompress(self.p.db['zrevoked'])
  96. self.assertTrue(pickled)
  97. saved = pickle.loads(pickled)
  98. for item in data2:
  99. self.assertIn(item, saved)
  100. class SimpleReq(object):
  101. def __init__(self, name):
  102. self.name = name
  103. class test_state(StateResetCase):
  104. def test_accepted(self, requests=[SimpleReq('foo'),
  105. SimpleReq('bar'),
  106. SimpleReq('baz'),
  107. SimpleReq('baz')]):
  108. for request in requests:
  109. state.task_accepted(request)
  110. for req in requests:
  111. self.assertIn(req, state.active_requests)
  112. self.assertEqual(state.total_count['foo'], 1)
  113. self.assertEqual(state.total_count['bar'], 1)
  114. self.assertEqual(state.total_count['baz'], 2)
  115. def test_ready(self, requests=[SimpleReq('foo'),
  116. SimpleReq('bar')]):
  117. for request in requests:
  118. state.task_accepted(request)
  119. self.assertEqual(len(state.active_requests), 2)
  120. for request in requests:
  121. state.task_ready(request)
  122. self.assertEqual(len(state.active_requests), 0)