test_autoscale.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. from __future__ import absolute_import
  2. import sys
  3. from mock import Mock, patch
  4. from celery.concurrency.base import BasePool
  5. from celery.five import monotonic
  6. from celery.worker import state
  7. from celery.worker import autoscale
  8. from celery.tests.case import AppCase, sleepdeprived
  9. class Object(object):
  10. pass
  11. class MockPool(BasePool):
  12. shrink_raises_exception = False
  13. shrink_raises_ValueError = False
  14. def __init__(self, *args, **kwargs):
  15. super(MockPool, self).__init__(*args, **kwargs)
  16. self._pool = Object()
  17. self._pool._processes = self.limit
  18. def grow(self, n=1):
  19. self._pool._processes += n
  20. def shrink(self, n=1):
  21. if self.shrink_raises_exception:
  22. raise KeyError('foo')
  23. if self.shrink_raises_ValueError:
  24. raise ValueError('foo')
  25. self._pool._processes -= n
  26. @property
  27. def num_processes(self):
  28. return self._pool._processes
  29. class test_WorkerComponent(AppCase):
  30. def test_register_with_event_loop(self):
  31. parent = Mock(name='parent')
  32. parent.autoscale = True
  33. parent.consumer.on_task_message = set()
  34. w = autoscale.WorkerComponent(parent)
  35. self.assertIsNone(parent.autoscaler)
  36. self.assertTrue(w.enabled)
  37. hub = Mock(name='hub')
  38. w.create(parent)
  39. w.register_with_event_loop(parent, hub)
  40. self.assertIn(
  41. parent.autoscaler.maybe_scale,
  42. parent.consumer.on_task_message,
  43. )
  44. hub.call_repeatedly.assert_called_with(
  45. parent.autoscaler.keepalive, parent.autoscaler.maybe_scale,
  46. )
  47. parent.hub = hub
  48. hub.on_init = []
  49. w.instantiate = Mock()
  50. w.register_with_event_loop(parent, Mock(name='loop'))
  51. self.assertTrue(parent.consumer.on_task_message)
  52. class test_Autoscaler(AppCase):
  53. def setup(self):
  54. self.pool = MockPool(3)
  55. def test_stop(self):
  56. class Scaler(autoscale.Autoscaler):
  57. alive = True
  58. joined = False
  59. def is_alive(self):
  60. return self.alive
  61. def join(self, timeout=None):
  62. self.joined = True
  63. worker = Mock(name='worker')
  64. x = Scaler(self.pool, 10, 3, worker=worker)
  65. x._is_stopped.set()
  66. x.stop()
  67. self.assertTrue(x.joined)
  68. x.joined = False
  69. x.alive = False
  70. x.stop()
  71. self.assertFalse(x.joined)
  72. @sleepdeprived(autoscale)
  73. def test_body(self):
  74. worker = Mock(name='worker')
  75. x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
  76. x.body()
  77. self.assertEqual(x.pool.num_processes, 3)
  78. for i in range(20):
  79. state.reserved_requests.add(i)
  80. x.body()
  81. x.body()
  82. self.assertEqual(x.pool.num_processes, 10)
  83. self.assertTrue(worker.consumer._update_prefetch_count.called)
  84. state.reserved_requests.clear()
  85. x.body()
  86. self.assertEqual(x.pool.num_processes, 10)
  87. x._last_action = monotonic() - 10000
  88. x.body()
  89. self.assertEqual(x.pool.num_processes, 3)
  90. self.assertTrue(worker.consumer._update_prefetch_count.called)
  91. def test_run(self):
  92. class Scaler(autoscale.Autoscaler):
  93. scale_called = False
  94. def body(self):
  95. self.scale_called = True
  96. self._is_shutdown.set()
  97. worker = Mock(name='worker')
  98. x = Scaler(self.pool, 10, 3, worker=worker)
  99. x.run()
  100. self.assertTrue(x._is_shutdown.isSet())
  101. self.assertTrue(x._is_stopped.isSet())
  102. self.assertTrue(x.scale_called)
  103. def test_shrink_raises_exception(self):
  104. worker = Mock(name='worker')
  105. x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
  106. x.scale_up(3)
  107. x._last_action = monotonic() - 10000
  108. x.pool.shrink_raises_exception = True
  109. x.scale_down(1)
  110. @patch('celery.worker.autoscale.debug')
  111. def test_shrink_raises_ValueError(self, debug):
  112. worker = Mock(name='worker')
  113. x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
  114. x.scale_up(3)
  115. x._last_action = monotonic() - 10000
  116. x.pool.shrink_raises_ValueError = True
  117. x.scale_down(1)
  118. self.assertTrue(debug.call_count)
  119. def test_update_and_force(self):
  120. worker = Mock(name='worker')
  121. x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
  122. self.assertEqual(x.processes, 3)
  123. x.force_scale_up(5)
  124. self.assertEqual(x.processes, 8)
  125. x.update(5, None)
  126. self.assertEqual(x.processes, 5)
  127. x.force_scale_down(3)
  128. self.assertEqual(x.processes, 2)
  129. x.update(3, None)
  130. self.assertEqual(x.processes, 3)
  131. x.force_scale_down(1000)
  132. self.assertEqual(x.min_concurrency, 0)
  133. self.assertEqual(x.processes, 0)
  134. x.force_scale_up(1000)
  135. x.min_concurrency = 1
  136. x.force_scale_down(1)
  137. x.update(max=300, min=10)
  138. x.update(max=300, min=2)
  139. x.update(max=None, min=None)
  140. def test_info(self):
  141. worker = Mock(name='worker')
  142. x = autoscale.Autoscaler(self.pool, 10, 3, worker=worker)
  143. info = x.info()
  144. self.assertEqual(info['max'], 10)
  145. self.assertEqual(info['min'], 3)
  146. self.assertEqual(info['current'], 3)
  147. @patch('os._exit')
  148. def test_thread_crash(self, _exit):
  149. class _Autoscaler(autoscale.Autoscaler):
  150. def body(self):
  151. self._is_shutdown.set()
  152. raise OSError('foo')
  153. worker = Mock(name='worker')
  154. x = _Autoscaler(self.pool, 10, 3, worker=worker)
  155. stderr = Mock()
  156. p, sys.stderr = sys.stderr, stderr
  157. try:
  158. x.run()
  159. finally:
  160. sys.stderr = p
  161. _exit.assert_called_with(1)
  162. self.assertTrue(stderr.write.call_count)