test_sets.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. from __future__ import absolute_import
  2. import anyjson
  3. import warnings
  4. from mock import Mock, patch
  5. from celery import uuid
  6. from celery.result import TaskSetResult
  7. from celery.task import Task
  8. from celery.canvas import Signature
  9. from celery.tests.tasks.test_result import make_mock_group
  10. from celery.tests.case import AppCase
  11. class SetsCase(AppCase):
  12. def setup(self):
  13. with warnings.catch_warnings(record=True):
  14. from celery.task import sets
  15. self.sets = sets
  16. self.subtask = sets.subtask
  17. self.TaskSet = sets.TaskSet
  18. class MockTask(Task):
  19. app = self.app
  20. name = 'tasks.add'
  21. def run(self, x, y, **kwargs):
  22. return x + y
  23. @classmethod
  24. def apply_async(cls, args, kwargs, **options):
  25. return (args, kwargs, options)
  26. @classmethod
  27. def apply(cls, args, kwargs, **options):
  28. return (args, kwargs, options)
  29. self.MockTask = MockTask
  30. class test_TaskSetResult(AppCase):
  31. def setup(self):
  32. self.size = 10
  33. self.ts = TaskSetResult(uuid(), make_mock_group(self.app, self.size))
  34. def test_total(self):
  35. self.assertEqual(self.ts.total, self.size)
  36. def test_compat_properties(self):
  37. self.assertEqual(self.ts.taskset_id, self.ts.id)
  38. self.ts.taskset_id = 'foo'
  39. self.assertEqual(self.ts.taskset_id, 'foo')
  40. def test_compat_subtasks_kwarg(self):
  41. x = TaskSetResult(uuid(), subtasks=[1, 2, 3])
  42. self.assertEqual(x.results, [1, 2, 3])
  43. def test_itersubtasks(self):
  44. it = self.ts.itersubtasks()
  45. for i, t in enumerate(it):
  46. self.assertEqual(t.get(), i)
  47. class test_App(AppCase):
  48. def test_TaskSet(self):
  49. with warnings.catch_warnings(record=True):
  50. ts = self.app.TaskSet()
  51. self.assertListEqual(ts.tasks, [])
  52. self.assertIs(ts.app, self.app)
  53. class test_subtask(SetsCase):
  54. def test_behaves_like_type(self):
  55. s = self.subtask('tasks.add', (2, 2), {'cache': True},
  56. {'routing_key': 'CPU-bound'})
  57. self.assertDictEqual(self.subtask(s), s)
  58. def test_task_argument_can_be_task_cls(self):
  59. s = self.subtask(self.MockTask, (2, 2))
  60. self.assertEqual(s.task, self.MockTask.name)
  61. def test_apply_async(self):
  62. s = self.MockTask.subtask(
  63. (2, 2), {'cache': True}, {'routing_key': 'CPU-bound'},
  64. )
  65. args, kwargs, options = s.apply_async()
  66. self.assertTupleEqual(args, (2, 2))
  67. self.assertDictEqual(kwargs, {'cache': True})
  68. self.assertDictEqual(options, {'routing_key': 'CPU-bound'})
  69. def test_delay_argmerge(self):
  70. s = self.MockTask.subtask(
  71. (2, ), {'cache': True}, {'routing_key': 'CPU-bound'},
  72. )
  73. args, kwargs, options = s.delay(10, cache=False, other='foo')
  74. self.assertTupleEqual(args, (10, 2))
  75. self.assertDictEqual(kwargs, {'cache': False, 'other': 'foo'})
  76. self.assertDictEqual(options, {'routing_key': 'CPU-bound'})
  77. def test_apply_async_argmerge(self):
  78. s = self.MockTask.subtask(
  79. (2, ), {'cache': True}, {'routing_key': 'CPU-bound'},
  80. )
  81. args, kwargs, options = s.apply_async((10, ),
  82. {'cache': False, 'other': 'foo'},
  83. routing_key='IO-bound',
  84. exchange='fast')
  85. self.assertTupleEqual(args, (10, 2))
  86. self.assertDictEqual(kwargs, {'cache': False, 'other': 'foo'})
  87. self.assertDictEqual(options, {'routing_key': 'IO-bound',
  88. 'exchange': 'fast'})
  89. def test_apply_argmerge(self):
  90. s = self.MockTask.subtask(
  91. (2, ), {'cache': True}, {'routing_key': 'CPU-bound'},
  92. )
  93. args, kwargs, options = s.apply((10, ),
  94. {'cache': False, 'other': 'foo'},
  95. routing_key='IO-bound',
  96. exchange='fast')
  97. self.assertTupleEqual(args, (10, 2))
  98. self.assertDictEqual(kwargs, {'cache': False, 'other': 'foo'})
  99. self.assertDictEqual(
  100. options, {'routing_key': 'IO-bound', 'exchange': 'fast'},
  101. )
  102. def test_is_JSON_serializable(self):
  103. s = self.MockTask.subtask(
  104. (2, ), {'cache': True}, {'routing_key': 'CPU-bound'},
  105. )
  106. s.args = list(s.args) # tuples are not preserved
  107. # but this doesn't matter.
  108. self.assertEqual(s, self.subtask(anyjson.loads(anyjson.dumps(s))))
  109. def test_repr(self):
  110. s = self.MockTask.subtask((2, ), {'cache': True})
  111. self.assertIn('2', repr(s))
  112. self.assertIn('cache=True', repr(s))
  113. def test_reduce(self):
  114. s = self.MockTask.subtask((2, ), {'cache': True})
  115. cls, args = s.__reduce__()
  116. self.assertDictEqual(dict(cls(*args)), dict(s))
  117. class test_TaskSet(SetsCase):
  118. def test_task_arg_can_be_iterable__compat(self):
  119. ts = self.TaskSet([self.MockTask.subtask((i, i))
  120. for i in (2, 4, 8)], app=self.app)
  121. self.assertEqual(len(ts), 3)
  122. def test_respects_ALWAYS_EAGER(self):
  123. app = self.app
  124. class MockTaskSet(self.TaskSet):
  125. applied = 0
  126. def apply(self, *args, **kwargs):
  127. self.applied += 1
  128. ts = MockTaskSet(
  129. [self.MockTask.subtask((i, i)) for i in (2, 4, 8)],
  130. app=self.app,
  131. )
  132. app.conf.CELERY_ALWAYS_EAGER = True
  133. ts.apply_async()
  134. self.assertEqual(ts.applied, 1)
  135. app.conf.CELERY_ALWAYS_EAGER = False
  136. with patch('celery.task.sets.get_current_worker_task') as gwt:
  137. parent = gwt.return_value = Mock()
  138. ts.apply_async()
  139. self.assertTrue(parent.add_trail.called)
  140. def test_apply_async(self):
  141. applied = [0]
  142. class mocksubtask(Signature):
  143. def apply_async(self, *args, **kwargs):
  144. applied[0] += 1
  145. ts = self.TaskSet([mocksubtask(self.MockTask, (i, i))
  146. for i in (2, 4, 8)], app=self.app)
  147. ts.apply_async()
  148. self.assertEqual(applied[0], 3)
  149. class Publisher(object):
  150. def send(self, *args, **kwargs):
  151. pass
  152. ts.apply_async(publisher=Publisher())
  153. # setting current_task
  154. @self.app.task(shared=False)
  155. def xyz():
  156. pass
  157. from celery._state import _task_stack
  158. xyz.push_request()
  159. _task_stack.push(xyz)
  160. try:
  161. ts.apply_async(publisher=Publisher())
  162. finally:
  163. _task_stack.pop()
  164. xyz.pop_request()
  165. def test_apply(self):
  166. applied = [0]
  167. class mocksubtask(Signature):
  168. def apply(self, *args, **kwargs):
  169. applied[0] += 1
  170. ts = self.TaskSet([mocksubtask(self.MockTask, (i, i))
  171. for i in (2, 4, 8)], app=self.app)
  172. ts.apply()
  173. self.assertEqual(applied[0], 3)
  174. def test_set_app(self):
  175. ts = self.TaskSet([], app=self.app)
  176. ts.app = 42
  177. self.assertEqual(ts.app, 42)
  178. def test_set_tasks(self):
  179. ts = self.TaskSet([], app=self.app)
  180. ts.tasks = [1, 2, 3]
  181. self.assertEqual(ts, [1, 2, 3])
  182. def test_set_Publisher(self):
  183. ts = self.TaskSet([], app=self.app)
  184. ts.Publisher = 42
  185. self.assertEqual(ts.Publisher, 42)