from __future__ import absolute_import import errno import socket import time from itertools import cycle from mock import Mock, call, patch from nose import SkipTest from celery.five import items, range from celery.utils.functional import noop from celery.tests.case import AppCase try: from celery.concurrency import prefork as mp from celery.concurrency import asynpool except ImportError: class _mp(object): RUN = 0x1 class TaskPool(object): _pool = Mock() def __init__(self, *args, **kwargs): pass def start(self): pass def stop(self): pass def apply_async(self, *args, **kwargs): pass mp = _mp() # noqa asynpool = None # noqa class Object(object): # for writeable attributes. def __init__(self, **kwargs): [setattr(self, k, v) for k, v in items(kwargs)] class MockResult(object): def __init__(self, value, pid): self.value = value self.pid = pid def worker_pids(self): return [self.pid] def get(self): return self.value class MockPool(object): started = False closed = False joined = False terminated = False _state = None def __init__(self, *args, **kwargs): self.started = True self._timeout_handler = Mock() self._result_handler = Mock() self.maintain_pool = Mock() self._state = mp.RUN self._processes = kwargs.get('processes') self._pool = [Object(pid=i, inqW_fd=1, outqR_fd=2) for i in range(self._processes)] self._current_proc = cycle(range(self._processes)) def close(self): self.closed = True self._state = 'CLOSE' def join(self): self.joined = True def terminate(self): self.terminated = True def terminate_job(self, *args, **kwargs): pass def restart(self, *args, **kwargs): pass def handle_result_event(self, *args, **kwargs): pass def flush(self): pass def grow(self, n=1): self._processes += n def shrink(self, n=1): self._processes -= n def apply_async(self, *args, **kwargs): pass def register_with_event_loop(self, loop): pass class ExeMockPool(MockPool): def apply_async(self, target, args=(), kwargs={}, callback=noop): from threading import Timer res = target(*args, **kwargs) Timer(0.1, callback, (res, )).start() return MockResult(res, next(self._current_proc)) class TaskPool(mp.TaskPool): Pool = BlockingPool = MockPool class ExeMockTaskPool(mp.TaskPool): Pool = BlockingPool = ExeMockPool class PoolCase(AppCase): def setup(self): try: import multiprocessing # noqa except ImportError: raise SkipTest('multiprocessing not supported') class test_AsynPool(PoolCase): def test_gen_not_started(self): def gen(): yield 1 yield 2 g = gen() self.assertTrue(asynpool.gen_not_started(g)) next(g) self.assertFalse(asynpool.gen_not_started(g)) list(g) self.assertFalse(asynpool.gen_not_started(g)) def test_select(self): ebadf = socket.error() ebadf.errno = errno.EBADF with patch('select.select') as select: select.return_value = ([3], [], []) self.assertEqual( asynpool._select(set([3])), ([3], [], 0), ) select.return_value = ([], [], [3]) self.assertEqual( asynpool._select(set([3]), None, set([3])), ([3], [], 0), ) eintr = socket.error() eintr.errno = errno.EINTR select.side_effect = eintr readers = set([3]) self.assertEqual(asynpool._select(readers), ([], [], 1)) self.assertIn(3, readers) with patch('select.select') as select: select.side_effect = ebadf readers = set([3]) self.assertEqual(asynpool._select(readers), ([], [], 1)) select.assert_has_calls([call([3], [], [], 0)]) self.assertNotIn(3, readers) with patch('select.select') as select: select.side_effect = MemoryError() with self.assertRaises(MemoryError): asynpool._select(set([1])) with patch('select.select') as select: def se(*args): select.side_effect = MemoryError() raise ebadf select.side_effect = se with self.assertRaises(MemoryError): asynpool._select(set([3])) with patch('select.select') as select: def se2(*args): select.side_effect = socket.error() select.side_effect.errno = 1321 raise ebadf select.side_effect = se2 with self.assertRaises(socket.error): asynpool._select(set([3])) with patch('select.select') as select: select.side_effect = socket.error() select.side_effect.errno = 34134 with self.assertRaises(socket.error): asynpool._select(set([3])) def test_promise(self): fun = Mock() x = asynpool.promise(fun, (1, ), {'foo': 1}) x() self.assertTrue(x.ready) fun.assert_called_with(1, foo=1) def test_Worker(self): w = asynpool.Worker(Mock(), Mock()) w.on_loop_start(1234) w.outq.put.assert_called_with((asynpool.WORKER_UP, (1234, ))) class test_ResultHandler(PoolCase): def test_process_result(self): x = asynpool.ResultHandler( Mock(), Mock(), {}, Mock(), Mock(), Mock(), Mock(), Mock(), fileno_to_outq={}, on_process_alive=Mock(), on_job_ready=Mock(), ) self.assertTrue(x) hub = Mock(name='hub') recv = x._recv_message = Mock(name='recv_message') recv.return_value = iter([]) x.on_state_change = Mock() x.register_with_event_loop(hub) proc = x.fileno_to_outq[3] = Mock() reader = proc.outq._reader reader.poll.return_value = False x.handle_event(6) # KeyError x.handle_event(3) x._recv_message.assert_called_with( hub.add_reader, 3, x.on_state_change, ) class test_TaskPool(PoolCase): def test_start(self): pool = TaskPool(10) pool.start() self.assertTrue(pool._pool.started) self.assertTrue(pool._pool._state == asynpool.RUN) _pool = pool._pool pool.stop() self.assertTrue(_pool.closed) self.assertTrue(_pool.joined) pool.stop() pool.start() _pool = pool._pool pool.terminate() pool.terminate() self.assertTrue(_pool.terminated) def test_apply_async(self): pool = TaskPool(10) pool.start() pool.apply_async(lambda x: x, (2, ), {}) def test_grow_shrink(self): pool = TaskPool(10) pool.start() self.assertEqual(pool._pool._processes, 10) pool.grow() self.assertEqual(pool._pool._processes, 11) pool.shrink(2) self.assertEqual(pool._pool._processes, 9) def test_info(self): pool = TaskPool(10) procs = [Object(pid=i) for i in range(pool.limit)] class _Pool(object): _pool = procs _maxtasksperchild = None timeout = 10 soft_timeout = 5 def human_write_stats(self, *args, **kwargs): return {} pool._pool = _Pool() info = pool.info self.assertEqual(info['max-concurrency'], pool.limit) self.assertEqual(info['max-tasks-per-child'], 'N/A') self.assertEqual(info['timeouts'], (5, 10)) def test_num_processes(self): pool = TaskPool(7) pool.start() self.assertEqual(pool.num_processes, 7) def test_restart(self): raise SkipTest('functional test') def get_pids(pool): return set([p.pid for p in pool._pool._pool]) tp = self.TaskPool(5) time.sleep(0.5) tp.start() pids = get_pids(tp) tp.restart() time.sleep(0.5) self.assertEqual(pids, get_pids(tp))