from __future__ import absolute_import import anyjson import warnings from mock import Mock, patch from celery import uuid from celery.result import TaskSetResult from celery.task import Task from celery.canvas import Signature from celery.tests.tasks.test_result import make_mock_group from celery.tests.case import AppCase class SetsCase(AppCase): def setup(self): with warnings.catch_warnings(record=True): from celery.task import sets self.sets = sets self.subtask = sets.subtask self.TaskSet = sets.TaskSet class MockTask(Task): app = self.app name = 'tasks.add' def run(self, x, y, **kwargs): return x + y @classmethod def apply_async(cls, args, kwargs, **options): return (args, kwargs, options) @classmethod def apply(cls, args, kwargs, **options): return (args, kwargs, options) self.MockTask = MockTask class test_TaskSetResult(AppCase): def setup(self): self.size = 10 self.ts = TaskSetResult(uuid(), make_mock_group(self.app, self.size)) def test_total(self): self.assertEqual(self.ts.total, self.size) def test_compat_properties(self): self.assertEqual(self.ts.taskset_id, self.ts.id) self.ts.taskset_id = 'foo' self.assertEqual(self.ts.taskset_id, 'foo') def test_compat_subtasks_kwarg(self): x = TaskSetResult(uuid(), subtasks=[1, 2, 3]) self.assertEqual(x.results, [1, 2, 3]) def test_itersubtasks(self): it = self.ts.itersubtasks() for i, t in enumerate(it): self.assertEqual(t.get(), i) class test_App(AppCase): def test_TaskSet(self): with warnings.catch_warnings(record=True): ts = self.app.TaskSet() self.assertListEqual(ts.tasks, []) self.assertIs(ts.app, self.app) class test_subtask(SetsCase): def test_behaves_like_type(self): s = self.subtask('tasks.add', (2, 2), {'cache': True}, {'routing_key': 'CPU-bound'}) self.assertDictEqual(self.subtask(s), s) def test_task_argument_can_be_task_cls(self): s = self.subtask(self.MockTask, (2, 2)) self.assertEqual(s.task, self.MockTask.name) def test_apply_async(self): s = self.MockTask.subtask( (2, 2), {'cache': True}, {'routing_key': 'CPU-bound'}, ) args, kwargs, options = s.apply_async() self.assertTupleEqual(args, (2, 2)) self.assertDictEqual(kwargs, {'cache': True}) self.assertDictEqual(options, {'routing_key': 'CPU-bound'}) def test_delay_argmerge(self): s = self.MockTask.subtask( (2, ), {'cache': True}, {'routing_key': 'CPU-bound'}, ) args, kwargs, options = s.delay(10, cache=False, other='foo') self.assertTupleEqual(args, (10, 2)) self.assertDictEqual(kwargs, {'cache': False, 'other': 'foo'}) self.assertDictEqual(options, {'routing_key': 'CPU-bound'}) def test_apply_async_argmerge(self): s = self.MockTask.subtask( (2, ), {'cache': True}, {'routing_key': 'CPU-bound'}, ) args, kwargs, options = s.apply_async((10, ), {'cache': False, 'other': 'foo'}, routing_key='IO-bound', exchange='fast') self.assertTupleEqual(args, (10, 2)) self.assertDictEqual(kwargs, {'cache': False, 'other': 'foo'}) self.assertDictEqual(options, {'routing_key': 'IO-bound', 'exchange': 'fast'}) def test_apply_argmerge(self): s = self.MockTask.subtask( (2, ), {'cache': True}, {'routing_key': 'CPU-bound'}, ) args, kwargs, options = s.apply((10, ), {'cache': False, 'other': 'foo'}, routing_key='IO-bound', exchange='fast') self.assertTupleEqual(args, (10, 2)) self.assertDictEqual(kwargs, {'cache': False, 'other': 'foo'}) self.assertDictEqual( options, {'routing_key': 'IO-bound', 'exchange': 'fast'}, ) def test_is_JSON_serializable(self): s = self.MockTask.subtask( (2, ), {'cache': True}, {'routing_key': 'CPU-bound'}, ) s.args = list(s.args) # tuples are not preserved # but this doesn't matter. self.assertEqual(s, self.subtask(anyjson.loads(anyjson.dumps(s)))) def test_repr(self): s = self.MockTask.subtask((2, ), {'cache': True}) self.assertIn('2', repr(s)) self.assertIn('cache=True', repr(s)) def test_reduce(self): s = self.MockTask.subtask((2, ), {'cache': True}) cls, args = s.__reduce__() self.assertDictEqual(dict(cls(*args)), dict(s)) class test_TaskSet(SetsCase): def test_task_arg_can_be_iterable__compat(self): ts = self.TaskSet([self.MockTask.subtask((i, i)) for i in (2, 4, 8)], app=self.app) self.assertEqual(len(ts), 3) def test_respects_ALWAYS_EAGER(self): app = self.app class MockTaskSet(self.TaskSet): applied = 0 def apply(self, *args, **kwargs): self.applied += 1 ts = MockTaskSet( [self.MockTask.subtask((i, i)) for i in (2, 4, 8)], app=self.app, ) app.conf.CELERY_ALWAYS_EAGER = True ts.apply_async() self.assertEqual(ts.applied, 1) app.conf.CELERY_ALWAYS_EAGER = False with patch('celery.task.sets.get_current_worker_task') as gwt: parent = gwt.return_value = Mock() ts.apply_async() self.assertTrue(parent.add_trail.called) def test_apply_async(self): applied = [0] class mocksubtask(Signature): def apply_async(self, *args, **kwargs): applied[0] += 1 ts = self.TaskSet([mocksubtask(self.MockTask, (i, i)) for i in (2, 4, 8)], app=self.app) ts.apply_async() self.assertEqual(applied[0], 3) class Publisher(object): def send(self, *args, **kwargs): pass ts.apply_async(publisher=Publisher()) # setting current_task @self.app.task(shared=False) def xyz(): pass from celery._state import _task_stack xyz.push_request() _task_stack.push(xyz) try: ts.apply_async(publisher=Publisher()) finally: _task_stack.pop() xyz.pop_request() def test_apply(self): applied = [0] class mocksubtask(Signature): def apply(self, *args, **kwargs): applied[0] += 1 ts = self.TaskSet([mocksubtask(self.MockTask, (i, i)) for i in (2, 4, 8)], app=self.app) ts.apply() self.assertEqual(applied[0], 3) def test_set_app(self): ts = self.TaskSet([], app=self.app) ts.app = 42 self.assertEqual(ts.app, 42) def test_set_tasks(self): ts = self.TaskSet([], app=self.app) ts.tasks = [1, 2, 3] self.assertEqual(ts, [1, 2, 3]) def test_set_Publisher(self): ts = self.TaskSet([], app=self.app) ts.Publisher = 42 self.assertEqual(ts.Publisher, 42)