test_control.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583
  1. from __future__ import absolute_import
  2. import sys
  3. import socket
  4. from collections import defaultdict
  5. from datetime import datetime, timedelta
  6. from kombu import pidbox
  7. from mock import Mock, patch, call
  8. from celery.datastructures import AttributeDict
  9. from celery.five import Queue as FastQueue
  10. from celery.utils import uuid
  11. from celery.utils.timer2 import Timer
  12. from celery.worker import WorkController as _WC
  13. from celery.worker import consumer
  14. from celery.worker import control
  15. from celery.worker import state as worker_state
  16. from celery.worker.job import Request
  17. from celery.worker.state import revoked
  18. from celery.worker.control import Panel
  19. from celery.worker.pidbox import Pidbox, gPidbox
  20. from celery.tests.case import AppCase
  # Host name of the local machine: panels in these tests are created for
  # this name and control messages are addressed to it.
  hostname = socket.gethostname()
  class WorkController(object):
      """Minimal controller object installed as ``Consumer.controller``.

      Exposes an ``autoscaler`` attribute (``None`` unless a test replaces
      it) and a ``stats()`` method reporting the global task counter.
      """
      autoscaler = None

      def stats(self):
          """Return a mapping with the current ``worker_state.total_count``."""
          total = worker_state.total_count
          return dict(total=total)
  class Consumer(consumer.Consumer):
      """Consumer stub for exercising remote-control commands.

      The constructor does not call ``consumer.Consumer.__init__``; it only
      sets the attributes the control-panel tests read or replace.
      """

      def __init__(self, app):
          from celery.concurrency.base import BasePool

          self.app = app
          # tasks handed to the consumer are simply queued in ``buffer``
          self.buffer = buf = FastQueue()
          self.handle_task = buf.put
          self.timer = Timer()
          self.event_dispatcher = Mock()
          self.controller = WorkController()
          self.task_consumer = Mock()
          self.prefetch_multiplier = self.initial_prefetch_count = 1
          self.pool = BasePool(10)
          self.task_buckets = defaultdict(lambda: None)
  class test_Pidbox(AppCase):
      """Tests for ``celery.worker.pidbox.Pidbox``."""

      def test_shutdown(self):
          # shutdown() must hand the pidbox consumer's ``cancel`` to
          # ignore_errors and then close the channel.
          parent = Mock()
          with patch('celery.worker.pidbox.ignore_errors') as ignore_errors:
              box = Pidbox(parent)
              box._close_channel = Mock()
              self.assertIs(box.c, parent)
              box.consumer = node_consumer = Mock()
              cancel = node_consumer.cancel
              box.shutdown(parent)
              ignore_errors.assert_called_with(parent, cancel)
              box._close_channel.assert_called_with(parent)
  class test_Pidbox_green(AppCase):
      """Tests for ``celery.worker.pidbox.gPidbox``."""

      def test_stop(self):
          parent = Mock()
          g = gPidbox(parent)
          stopped = g._node_stopped = Mock()
          shutdown = g._node_shutdown = Mock()
          close_chan = g._close_channel = Mock()

          g.stop(parent)
          # stop() sets the shutdown event, waits on the stopped event,
          # closes the channel and clears both events.
          shutdown.set.assert_called_with()
          stopped.wait.assert_called_with()
          close_chan.assert_called_with(parent)
          self.assertIsNone(g._node_stopped)
          self.assertIsNone(g._node_shutdown)

          # ``reset_mock()`` clears the recorded calls. (``reset()`` would
          # merely be recorded as another child call, so the assertion below
          # would still be satisfied by the first stop().) A second stop
          # with no events left must still close the channel.
          close_chan.reset_mock()
          g.stop(parent)
          close_chan.assert_called_with(parent)

      def test_resets(self):
          parent = Mock()
          g = gPidbox(parent)
          g._resets = 100
          g.reset()
          self.assertEqual(g._resets, 101)

      def test_loop(self):
          parent = Mock()
          conn = parent.connect.return_value = self.app.connection()
          drain = conn.drain_events = Mock()
          g = gPidbox(parent)
          parent.connection = Mock()
          do_reset = g._do_reset = Mock()
          # list cell so the closure below can mutate the counter (no
          # ``nonlocal`` on Python 2)
          call_count = [0]

          def se(*args, **kwargs):
              # every drain requests a reset; after the third call the node
              # is told to shut down so the loop terminates
              if call_count[0] > 2:
                  g._node_shutdown.set()
              g.reset()
              call_count[0] += 1
          drain.side_effect = se
          g.loop(parent)

          self.assertEqual(do_reset.call_count, 4)
  class test_ControlPanel(AppCase):
      """Tests for the remote-control command handlers registered in ``Panel``.

      Most tests build a mailbox ``Node`` (the "panel") whose state carries
      the test app, the local host name and a consumer, then dispatch
      control commands to it with ``handle``/``handle_message``.
      """

      def setup(self):
          # default panel plus a task with a known rate limit that several
          # tests look up by name
          self.panel = self.create_panel(consumer=Consumer(self.app))

          @self.app.task(rate_limit=200, shared=False)
          def mytask():
              pass
          self.mytask = mytask

      def create_state(self, **kwargs):
          """Return the panel state as an ``AttributeDict`` of ``kwargs``.

          ``app`` defaults to the test app and ``hostname`` to the local
          host name when not supplied.
          """
          kwargs.setdefault('app', self.app)
          kwargs.setdefault('hostname', hostname)
          return AttributeDict(kwargs)

      def create_panel(self, **kwargs):
          """Return a mailbox ``Node`` for the local host with ``Panel`` handlers.

          ``kwargs`` are forwarded to :meth:`create_state`.
          """
          return self.app.control.mailbox.Node(hostname=hostname,
                                               state=self.create_state(**kwargs),
                                               handlers=Panel.data)

      def test_enable_events(self):
          consumer = Consumer(self.app)
          panel = self.create_panel(consumer=consumer)
          evd = consumer.event_dispatcher
          evd.groups = set()
          panel.handle('enable_events')
          self.assertIn('task', evd.groups)
          evd.groups = set(['task'])
          self.assertIn('already enabled', panel.handle('enable_events')['ok'])

      def test_disable_events(self):
          consumer = Consumer(self.app)
          panel = self.create_panel(consumer=consumer)
          evd = consumer.event_dispatcher
          evd.enabled = True
          evd.groups = set(['task'])
          panel.handle('disable_events')
          self.assertNotIn('task', evd.groups)
          self.assertIn('already disabled', panel.handle('disable_events')['ok'])

      def test_clock(self):
          consumer = Consumer(self.app)
          panel = self.create_panel(consumer=consumer)
          panel.state.app.clock.value = 313
          x = panel.handle('clock')
          self.assertEqual(x['clock'], 313)

      def test_hello(self):
          consumer = Consumer(self.app)
          panel = self.create_panel(consumer=consumer)
          panel.state.app.clock.value = 313
          worker_state.revoked.add('revoked1')
          try:
              x = panel.handle('hello', {'from_node': 'george@vandelay.com'})
              self.assertIn('revoked1', x['revoked'])
              self.assertEqual(x['clock'], 314)  # incremented
          finally:
              worker_state.revoked.discard('revoked1')

      def test_conf(self):
          # This test used to start with a bare ``return``, which left the
          # body unreachable and reported a silent pass.  Skip explicitly so
          # the disabled state shows up in test reports.
          self.skipTest('dump_conf test is disabled')
          consumer = Consumer(self.app)
          panel = self.create_panel(consumer=consumer)
          self.app.conf.SOME_KEY6 = 'hello world'
          x = panel.handle('dump_conf')
          self.assertIn('SOME_KEY6', x)

      def test_election(self):
          consumer = Consumer(self.app)
          panel = self.create_panel(consumer=consumer)
          consumer.gossip = Mock()
          panel.handle(
              'election', {'id': 'id', 'topic': 'topic', 'action': 'action'},
          )
          consumer.gossip.election.assert_called_with('id', 'topic', 'action')

      def test_heartbeat(self):
          consumer = Consumer(self.app)
          panel = self.create_panel(consumer=consumer)
          consumer.event_dispatcher.enabled = True
          panel.handle('heartbeat')
          # call_args is (args, kwargs); the positional args must be exactly
          # the event type
          self.assertIn(('worker-heartbeat', ),
                        consumer.event_dispatcher.send.call_args)

      def test_time_limit(self):
          panel = self.create_panel(consumer=Mock())
          r = panel.handle('time_limit', arguments=dict(
              task_name=self.mytask.name, hard=30, soft=10))
          self.assertEqual(
              (self.mytask.time_limit, self.mytask.soft_time_limit),
              (30, 10),
          )
          self.assertIn('ok', r)
          r = panel.handle('time_limit', arguments=dict(
              task_name=self.mytask.name, hard=None, soft=None))
          self.assertEqual(
              (self.mytask.time_limit, self.mytask.soft_time_limit),
              (None, None),
          )
          self.assertIn('ok', r)

          r = panel.handle('time_limit', arguments=dict(
              task_name='248e8afya9s8dh921eh928', hard=30))
          self.assertIn('error', r)

      def test_active_queues(self):
          import kombu

          x = kombu.Consumer(self.app.connection(),
                             [kombu.Queue('foo', kombu.Exchange('foo'), 'foo'),
                              kombu.Queue('bar', kombu.Exchange('bar'), 'bar')],
                             auto_declare=False)
          consumer = Mock()
          consumer.task_consumer = x
          panel = self.create_panel(consumer=consumer)
          r = panel.handle('active_queues')
          self.assertListEqual(list(sorted(q['name'] for q in r)),
                               ['bar', 'foo'])

      def test_dump_tasks(self):
          info = '\n'.join(self.panel.handle('dump_tasks'))
          self.assertIn('mytask', info)
          self.assertIn('rate_limit=200', info)

      def test_stats(self):
          prev_count, worker_state.total_count = worker_state.total_count, 100
          try:
              self.assertDictContainsSubset({'total': 100},
                                            self.panel.handle('stats'))
          finally:
              worker_state.total_count = prev_count

      def test_report(self):
          self.panel.handle('report')

      def test_active(self):
          r = Request({
              'task': self.mytask.name,
              'id': 'do re mi',
              'args': (),
              'kwargs': {},
          }, app=self.app)
          worker_state.active_requests.add(r)
          try:
              self.assertTrue(self.panel.handle('dump_active'))
          finally:
              worker_state.active_requests.discard(r)

      def test_pool_grow(self):

          class MockPool(object):

              def __init__(self, size=1):
                  self.size = size

              def grow(self, n=1):
                  self.size += n

              def shrink(self, n=1):
                  self.size -= n

              @property
              def num_processes(self):
                  return self.size

          consumer = Consumer(self.app)
          consumer.prefetch_multiplier = 8
          consumer.qos = Mock(name='qos')
          consumer.pool = MockPool(1)
          panel = self.create_panel(consumer=consumer)

          panel.handle('pool_grow')
          self.assertEqual(consumer.pool.size, 2)
          consumer.qos.increment_eventually.assert_called_with(8)
          self.assertEqual(consumer.initial_prefetch_count, 16)
          panel.handle('pool_shrink')
          self.assertEqual(consumer.pool.size, 1)
          consumer.qos.decrement_eventually.assert_called_with(8)
          self.assertEqual(consumer.initial_prefetch_count, 8)

          # with an autoscaler present, grow/shrink go through it instead
          panel.state.consumer = Mock()
          panel.state.consumer.controller = Mock()
          sc = panel.state.consumer.controller.autoscaler = Mock()
          panel.handle('pool_grow')
          self.assertTrue(sc.force_scale_up.called)
          panel.handle('pool_shrink')
          self.assertTrue(sc.force_scale_down.called)

      def test_add__cancel_consumer(self):

          class MockConsumer(object):
              queues = []
              cancelled = []
              consuming = False

              def add_queue(self, queue):
                  self.queues.append(queue.name)

              def consume(self):
                  self.consuming = True

              def cancel_by_queue(self, queue):
                  self.cancelled.append(queue)

              def consuming_from(self, queue):
                  return queue in self.queues

          consumer = Consumer(self.app)
          consumer.task_consumer = MockConsumer()
          panel = self.create_panel(consumer=consumer)
          panel.handle('add_consumer', {'queue': 'MyQueue'})
          self.assertIn('MyQueue', consumer.task_consumer.queues)
          self.assertTrue(consumer.task_consumer.consuming)
          panel.handle('add_consumer', {'queue': 'MyQueue'})
          panel.handle('cancel_consumer', {'queue': 'MyQueue'})
          self.assertIn('MyQueue', consumer.task_consumer.cancelled)

      def test_revoked(self):
          worker_state.revoked.clear()
          worker_state.revoked.add('a1')
          worker_state.revoked.add('a2')

          try:
              self.assertEqual(sorted(self.panel.handle('dump_revoked')),
                               ['a1', 'a2'])
          finally:
              worker_state.revoked.clear()

      def test_dump_schedule(self):
          consumer = Consumer(self.app)
          panel = self.create_panel(consumer=consumer)
          self.assertFalse(panel.handle('dump_schedule'))
          r = Request({
              'task': self.mytask.name,
              'id': 'CAFEBABE',
              'args': (),
              'kwargs': {},
          }, app=self.app)
          consumer.timer.schedule.enter_at(
              consumer.timer.Entry(lambda x: x, (r, )),
              datetime.now() + timedelta(seconds=10))
          consumer.timer.schedule.enter_at(
              consumer.timer.Entry(lambda x: x, (object(), )),
              datetime.now() + timedelta(seconds=10))
          self.assertTrue(panel.handle('dump_schedule'))

      def test_dump_reserved(self):
          consumer = Consumer(self.app)
          worker_state.reserved_requests.add(Request({
              'task': self.mytask.name,
              'id': uuid(),
              'args': (2, 2),
              'kwargs': {},
          }, app=self.app))
          try:
              panel = self.create_panel(consumer=consumer)
              response = panel.handle('dump_reserved', {'safe': True})
              self.assertDictContainsSubset(
                  {'name': self.mytask.name,
                   'args': (2, 2),
                   'kwargs': {},
                   'hostname': hostname},
                  response[0],
              )
              worker_state.reserved_requests.clear()
              self.assertFalse(panel.handle('dump_reserved'))
          finally:
              worker_state.reserved_requests.clear()

      def test_rate_limit_invalid_rate_limit_string(self):
          e = self.panel.handle('rate_limit', arguments=dict(
              task_name='tasks.add', rate_limit='x1240301#%!'))
          self.assertIn('Invalid rate limit string', e.get('error'))

      def test_rate_limit(self):

          class xConsumer(object):
              reset = False

              def reset_rate_limits(self):
                  self.reset = True

          consumer = xConsumer()
          panel = self.create_panel(app=self.app, consumer=consumer)

          task = self.app.tasks[self.mytask.name]
          panel.handle('rate_limit', arguments=dict(task_name=task.name,
                                                    rate_limit='100/m'))
          self.assertEqual(task.rate_limit, '100/m')
          self.assertTrue(consumer.reset)
          consumer.reset = False
          panel.handle('rate_limit', arguments=dict(task_name=task.name,
                                                    rate_limit=0))
          self.assertEqual(task.rate_limit, 0)
          self.assertTrue(consumer.reset)

      def test_rate_limit_nonexistant_task(self):
          r = self.panel.handle('rate_limit', arguments={
              'task_name': 'xxxx.does.not.exist',
              'rate_limit': '1000/s'})
          # an unknown task must be reported back rather than ignored
          self.assertIn('error', r)

      def test_unexposed_command(self):
          with self.assertRaises(KeyError):
              self.panel.handle('foo', arguments={})

      def test_revoke_with_name(self):
          tid = uuid()
          m = {'method': 'revoke',
               'destination': hostname,
               'arguments': {'task_id': tid,
                             'task_name': self.mytask.name}}
          try:
              self.panel.handle_message(m, None)
              self.assertIn(tid, revoked)
          finally:
              # ``revoked`` is process-global worker state; don't leak ids
              # into other tests
              revoked.discard(tid)

      def test_revoke_with_name_not_in_registry(self):
          tid = uuid()
          m = {'method': 'revoke',
               'destination': hostname,
               'arguments': {'task_id': tid,
                             'task_name': 'xxxxxxxxx33333333388888'}}
          try:
              self.panel.handle_message(m, None)
              self.assertIn(tid, revoked)
          finally:
              revoked.discard(tid)

      def test_revoke(self):
          tid = uuid()
          m = {'method': 'revoke',
               'destination': hostname,
               'arguments': {'task_id': tid}}
          try:
              self.panel.handle_message(m, None)
              self.assertIn(tid, revoked)

              # a message addressed to another node must not revoke anything
              m = {'method': 'revoke',
                   'destination': 'does.not.exist',
                   'arguments': {'task_id': tid + 'xxx'}}
              self.panel.handle_message(m, None)
              self.assertNotIn(tid + 'xxx', revoked)
          finally:
              revoked.discard(tid)

      def test_revoke_terminate(self):
          request = Mock()
          request.id = tid = uuid()
          unknown_id = uuid()
          worker_state.reserved_requests.add(request)
          try:
              r = control.revoke(Mock(), tid, terminate=True)
              self.assertIn(tid, revoked)
              self.assertTrue(request.terminate.call_count)
              self.assertIn('terminate:', r['ok'])
              # unknown task id only revokes
              r = control.revoke(Mock(), unknown_id, terminate=True)
              self.assertIn('tasks unknown', r['ok'])
          finally:
              worker_state.reserved_requests.discard(request)
              revoked.discard(tid)
              revoked.discard(unknown_id)

      def test_autoscale(self):
          self.panel.state.consumer = Mock()
          self.panel.state.consumer.controller = Mock()
          sc = self.panel.state.consumer.controller.autoscaler = Mock()
          sc.update.return_value = 10, 2
          m = {'method': 'autoscale',
               'destination': hostname,
               'arguments': {'max': '10', 'min': '2'}}
          r = self.panel.handle_message(m, None)
          self.assertIn('ok', r)

          # without an autoscaler the command must report an error
          self.panel.state.consumer.controller.autoscaler = None
          r = self.panel.handle_message(m, None)
          self.assertIn('error', r)

      def test_ping(self):
          m = {'method': 'ping',
               'destination': hostname}
          r = self.panel.handle_message(m, None)
          self.assertEqual(r, {'ok': 'pong'})

      def test_shutdown(self):
          m = {'method': 'shutdown',
               'destination': hostname}
          with self.assertRaises(SystemExit):
              self.panel.handle_message(m, None)

      def test_panel_reply(self):

          replies = []

          class _Node(pidbox.Node):

              def reply(self, data, exchange, routing_key, **kwargs):
                  replies.append(data)

          panel = _Node(hostname=hostname,
                        state=self.create_state(consumer=Consumer(self.app)),
                        handlers=Panel.data,
                        mailbox=self.app.control.mailbox)
          r = panel.dispatch('ping', reply_to={'exchange': 'x',
                                               'routing_key': 'x'})
          self.assertEqual(r, {'ok': 'pong'})
          self.assertDictEqual(replies[0], {panel.hostname: {'ok': 'pong'}})

      def test_pool_restart(self):
          consumer = Consumer(self.app)
          consumer.controller = _WC(app=self.app)
          consumer.controller.pool.restart = Mock()
          panel = self.create_panel(consumer=consumer)
          panel.app = self.app
          _import = panel.app.loader.import_from_cwd = Mock()
          _reload = Mock()

          # pool restarts are refused until CELERYD_POOL_RESTARTS is enabled
          with self.assertRaises(ValueError):
              panel.handle('pool_restart', {'reloader': _reload})

          self.app.conf.CELERYD_POOL_RESTARTS = True
          panel.handle('pool_restart', {'reloader': _reload})
          self.assertTrue(consumer.controller.pool.restart.called)
          self.assertFalse(_reload.called)
          self.assertFalse(_import.called)

      def test_pool_restart_import_modules(self):
          consumer = Consumer(self.app)
          consumer.controller = _WC(app=self.app)
          consumer.controller.pool.restart = Mock()
          panel = self.create_panel(consumer=consumer)
          panel.app = self.app
          _import = consumer.controller.app.loader.import_from_cwd = Mock()
          _reload = Mock()

          self.app.conf.CELERYD_POOL_RESTARTS = True
          panel.handle('pool_restart', {'modules': ['foo', 'bar'],
                                        'reloader': _reload})

          self.assertTrue(consumer.controller.pool.restart.called)
          self.assertFalse(_reload.called)
          self.assertItemsEqual(
              [call('bar'), call('foo')],
              _import.call_args_list,
          )

      def test_pool_restart_reload_modules(self):
          consumer = Consumer(self.app)
          consumer.controller = _WC(app=self.app)
          consumer.controller.pool.restart = Mock()
          panel = self.create_panel(consumer=consumer)
          panel.app = self.app
          _import = panel.app.loader.import_from_cwd = Mock()
          _reload = Mock()

          self.app.conf.CELERYD_POOL_RESTARTS = True
          # 'foo' is registered in sys.modules so it counts as already loaded
          with patch.dict(sys.modules, {'foo': None}):
              panel.handle('pool_restart', {'modules': ['foo'],
                                            'reload': False,
                                            'reloader': _reload})

              self.assertTrue(consumer.controller.pool.restart.called)
              self.assertFalse(_reload.called)
              self.assertFalse(_import.called)

              _import.reset_mock()
              _reload.reset_mock()
              consumer.controller.pool.restart.reset_mock()

              panel.handle('pool_restart', {'modules': ['foo'],
                                            'reload': True,
                                            'reloader': _reload})

              self.assertTrue(consumer.controller.pool.restart.called)
              self.assertTrue(_reload.called)
              self.assertFalse(_import.called)