test_prefork.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. from __future__ import absolute_import
  2. import errno
  3. import socket
  4. import time
  5. from itertools import cycle
  6. from mock import Mock, call, patch
  7. from nose import SkipTest
  8. from celery.five import items, range
  9. from celery.utils.functional import noop
  10. from celery.tests.case import AppCase
  11. try:
  12. from celery.concurrency import prefork as mp
  13. from celery.concurrency import asynpool
  14. except ImportError:
  15. class _mp(object):
  16. RUN = 0x1
  17. class TaskPool(object):
  18. _pool = Mock()
  19. def __init__(self, *args, **kwargs):
  20. pass
  21. def start(self):
  22. pass
  23. def stop(self):
  24. pass
  25. def apply_async(self, *args, **kwargs):
  26. pass
  27. mp = _mp() # noqa
  28. asynpool = None # noqa
  29. class Object(object): # for writeable attributes.
  30. def __init__(self, **kwargs):
  31. [setattr(self, k, v) for k, v in items(kwargs)]
  32. class MockResult(object):
  33. def __init__(self, value, pid):
  34. self.value = value
  35. self.pid = pid
  36. def worker_pids(self):
  37. return [self.pid]
  38. def get(self):
  39. return self.value
  40. class MockPool(object):
  41. started = False
  42. closed = False
  43. joined = False
  44. terminated = False
  45. _state = None
  46. def __init__(self, *args, **kwargs):
  47. self.started = True
  48. self._timeout_handler = Mock()
  49. self._result_handler = Mock()
  50. self.maintain_pool = Mock()
  51. self._state = mp.RUN
  52. self._processes = kwargs.get('processes')
  53. self._pool = [Object(pid=i, inqW_fd=1, outqR_fd=2)
  54. for i in range(self._processes)]
  55. self._current_proc = cycle(range(self._processes))
  56. def close(self):
  57. self.closed = True
  58. self._state = 'CLOSE'
  59. def join(self):
  60. self.joined = True
  61. def terminate(self):
  62. self.terminated = True
  63. def terminate_job(self, *args, **kwargs):
  64. pass
  65. def restart(self, *args, **kwargs):
  66. pass
  67. def handle_result_event(self, *args, **kwargs):
  68. pass
  69. def flush(self):
  70. pass
  71. def grow(self, n=1):
  72. self._processes += n
  73. def shrink(self, n=1):
  74. self._processes -= n
  75. def apply_async(self, *args, **kwargs):
  76. pass
  77. def register_with_event_loop(self, loop):
  78. pass
  79. class ExeMockPool(MockPool):
  80. def apply_async(self, target, args=(), kwargs={}, callback=noop):
  81. from threading import Timer
  82. res = target(*args, **kwargs)
  83. Timer(0.1, callback, (res, )).start()
  84. return MockResult(res, next(self._current_proc))
  85. class TaskPool(mp.TaskPool):
  86. Pool = BlockingPool = MockPool
  87. class ExeMockTaskPool(mp.TaskPool):
  88. Pool = BlockingPool = ExeMockPool
  89. class PoolCase(AppCase):
  90. def setup(self):
  91. try:
  92. import multiprocessing # noqa
  93. except ImportError:
  94. raise SkipTest('multiprocessing not supported')
  95. class test_AsynPool(PoolCase):
  96. def test_gen_not_started(self):
  97. def gen():
  98. yield 1
  99. yield 2
  100. g = gen()
  101. self.assertTrue(asynpool.gen_not_started(g))
  102. next(g)
  103. self.assertFalse(asynpool.gen_not_started(g))
  104. list(g)
  105. self.assertFalse(asynpool.gen_not_started(g))
  106. def test_select(self):
  107. ebadf = socket.error()
  108. ebadf.errno = errno.EBADF
  109. with patch('select.select') as select:
  110. select.return_value = ([3], [], [])
  111. self.assertEqual(
  112. asynpool._select(set([3])),
  113. ([3], [], 0),
  114. )
  115. select.return_value = ([], [], [3])
  116. self.assertEqual(
  117. asynpool._select(set([3]), None, set([3])),
  118. ([3], [], 0),
  119. )
  120. eintr = socket.error()
  121. eintr.errno = errno.EINTR
  122. select.side_effect = eintr
  123. readers = set([3])
  124. self.assertEqual(asynpool._select(readers), ([], [], 1))
  125. self.assertIn(3, readers)
  126. with patch('select.select') as select:
  127. select.side_effect = ebadf
  128. readers = set([3])
  129. self.assertEqual(asynpool._select(readers), ([], [], 1))
  130. select.assert_has_calls([call([3], [], [], 0)])
  131. self.assertNotIn(3, readers)
  132. with patch('select.select') as select:
  133. select.side_effect = MemoryError()
  134. with self.assertRaises(MemoryError):
  135. asynpool._select(set([1]))
  136. with patch('select.select') as select:
  137. def se(*args):
  138. select.side_effect = MemoryError()
  139. raise ebadf
  140. select.side_effect = se
  141. with self.assertRaises(MemoryError):
  142. asynpool._select(set([3]))
  143. with patch('select.select') as select:
  144. def se2(*args):
  145. select.side_effect = socket.error()
  146. select.side_effect.errno = 1321
  147. raise ebadf
  148. select.side_effect = se2
  149. with self.assertRaises(socket.error):
  150. asynpool._select(set([3]))
  151. with patch('select.select') as select:
  152. select.side_effect = socket.error()
  153. select.side_effect.errno = 34134
  154. with self.assertRaises(socket.error):
  155. asynpool._select(set([3]))
  156. def test_promise(self):
  157. fun = Mock()
  158. x = asynpool.promise(fun, (1, ), {'foo': 1})
  159. x()
  160. self.assertTrue(x.ready)
  161. fun.assert_called_with(1, foo=1)
  162. def test_Worker(self):
  163. w = asynpool.Worker(Mock(), Mock())
  164. w.on_loop_start(1234)
  165. w.outq.put.assert_called_with((asynpool.WORKER_UP, (1234, )))
  166. class test_ResultHandler(PoolCase):
  167. def test_process_result(self):
  168. x = asynpool.ResultHandler(
  169. Mock(), Mock(), {}, Mock(),
  170. Mock(), Mock(), Mock(), Mock(),
  171. fileno_to_outq={},
  172. on_process_alive=Mock(),
  173. on_job_ready=Mock(),
  174. )
  175. self.assertTrue(x)
  176. hub = Mock(name='hub')
  177. recv = x._recv_message = Mock(name='recv_message')
  178. recv.return_value = iter([])
  179. x.on_state_change = Mock()
  180. x.register_with_event_loop(hub)
  181. proc = x.fileno_to_outq[3] = Mock()
  182. reader = proc.outq._reader
  183. reader.poll.return_value = False
  184. x.handle_event(6) # KeyError
  185. x.handle_event(3)
  186. x._recv_message.assert_called_with(
  187. hub.add_reader, 3, x.on_state_change,
  188. )
  189. class test_TaskPool(PoolCase):
  190. def test_start(self):
  191. pool = TaskPool(10)
  192. pool.start()
  193. self.assertTrue(pool._pool.started)
  194. self.assertTrue(pool._pool._state == asynpool.RUN)
  195. _pool = pool._pool
  196. pool.stop()
  197. self.assertTrue(_pool.closed)
  198. self.assertTrue(_pool.joined)
  199. pool.stop()
  200. pool.start()
  201. _pool = pool._pool
  202. pool.terminate()
  203. pool.terminate()
  204. self.assertTrue(_pool.terminated)
  205. def test_apply_async(self):
  206. pool = TaskPool(10)
  207. pool.start()
  208. pool.apply_async(lambda x: x, (2, ), {})
  209. def test_grow_shrink(self):
  210. pool = TaskPool(10)
  211. pool.start()
  212. self.assertEqual(pool._pool._processes, 10)
  213. pool.grow()
  214. self.assertEqual(pool._pool._processes, 11)
  215. pool.shrink(2)
  216. self.assertEqual(pool._pool._processes, 9)
  217. def test_info(self):
  218. pool = TaskPool(10)
  219. procs = [Object(pid=i) for i in range(pool.limit)]
  220. class _Pool(object):
  221. _pool = procs
  222. _maxtasksperchild = None
  223. timeout = 10
  224. soft_timeout = 5
  225. def human_write_stats(self, *args, **kwargs):
  226. return {}
  227. pool._pool = _Pool()
  228. info = pool.info
  229. self.assertEqual(info['max-concurrency'], pool.limit)
  230. self.assertEqual(info['max-tasks-per-child'], 'N/A')
  231. self.assertEqual(info['timeouts'], (5, 10))
  232. def test_num_processes(self):
  233. pool = TaskPool(7)
  234. pool.start()
  235. self.assertEqual(pool.num_processes, 7)
  236. def test_restart(self):
  237. raise SkipTest('functional test')
  238. def get_pids(pool):
  239. return set([p.pid for p in pool._pool._pool])
  240. tp = self.TaskPool(5)
  241. time.sleep(0.5)
  242. tp.start()
  243. pids = get_pids(tp)
  244. tp.restart()
  245. time.sleep(0.5)
  246. self.assertEqual(pids, get_pids(tp))