from __future__ import absolute_import

from collections import Callable
from datetime import datetime, timedelta
from mock import patch

from kombu import Queue

from celery import Task
from celery.exceptions import Retry
from celery.five import items, range, string_t
from celery.result import EagerResult
from celery.utils import uuid
from celery.utils.timeutils import parse_iso8601

from celery.tests.case import AppCase, depends_on_current_app


# Module-level helper used as the body of ``self.mytask`` in
# ``TasksCase.setup``: accepts any arguments and always returns True.
def return_True(*args, **kwargs):
    # Task run functions can't be closures/lambdas, as they're pickled.
    return True


# Always raises ``Exception``; the message names the class of ``self``.
def raise_exception(self, **kwargs):
    raise Exception('%s error' % self.__class__)


# Abstract task base whose ``apply_async`` only increments the ``applied``
# counter (and returns None), so tests can check whether it was invoked.
class MockApplyTask(Task):
    abstract = True
    applied = 0

    def run(self, x, y):
        return x * y

    def apply_async(self, *args, **kwargs):
        self.applied += 1


# Common fixture: registers the tasks used by the test classes below.
class TasksCase(AppCase):

    # Creates every task on ``self.app`` with ``shared=False`` and stores
    # each one as an attribute of the test case.
    def setup(self):
        self.mytask = self.app.task(shared=False)(return_True)

        # Bound task keeping a running total in its ``count`` attribute;
        # a falsy ``increment_by`` is treated as 1.
        @self.app.task(bind=True, count=0, shared=False)
        def increment_counter(self, increment_by=1):
            self.count += increment_by or 1
            return self.count
        self.increment_counter = increment_counter

        # Task that always raises ``KeyError('foo')``.
        @self.app.task(shared=False)
        def raising():
            raise KeyError('foo')
        self.raising = raising

        # Counts its runs in ``iterations``.  Returns ``arg1`` once
        # ``request.retries`` reaches the effective limit (the
        # ``max_retries`` argument if given, else the task attribute) and
        # ``care`` is true; otherwise it raises ``self.retry(...)``.
        # With ``care=False`` it never returns normally.
        @self.app.task(bind=True, max_retries=3, iterations=0, shared=False)
        def retry_task(self, arg1, arg2, kwarg=1, max_retries=None, care=True):
            self.iterations += 1
            rmax = self.max_retries if max_retries is None else max_retries

            # Also exercises repr() of the request object on every run.
            assert repr(self.request)
            retries = self.request.retries
            if care and retries >= rmax:
                return arg1
            else:
                raise self.retry(countdown=0, max_retries=rmax)
        self.retry_task = retry_task

        # Same idea without positional arguments: retries until
        # ``request.retries`` is at least 3, then returns 42.
        @self.app.task(bind=True, max_retries=3, iterations=0, shared=False)
        def retry_task_noargs(self, **kwargs):
            self.iterations += 1

            if self.request.retries >= 3:
                return 42
            else:
                raise self.retry(countdown=0)
        self.retry_task_noargs = retry_task_noargs

        # Retry task built on ``MockApplyTask``, so its ``apply_async``
        # is the counting stub defined above.
        @self.app.task(bind=True, max_retries=3, iterations=0,
                       base=MockApplyTask, shared=False)
        def retry_task_mockapply(self, arg1, arg2, kwarg=1):
            self.iterations += 1

            retries = self.request.retries
            if retries >= 3:
                return arg1
            raise self.retry(countdown=0)
        self.retry_task_mockapply = retry_task_mockapply

        # Retry task that passes a ``MyCustomException`` to ``retry`` via
        # ``exc``; returns ``arg1 + kwarg`` once ``request.retries`` is at
        # least 3.
        @self.app.task(bind=True, max_retries=3, iterations=0, shared=False)
        def retry_task_customexc(self, arg1, arg2, kwarg=1, **kwargs):
            self.iterations += 1

            retries = self.request.retries
            if retries >= 3:
                return arg1 + kwarg
            else:
                try:
                    raise MyCustomException('Elaine Marie Benes')
                except MyCustomException as exc:
                    # NOTE(review): ``kwargs`` is updated here but is not
                    # passed to ``retry`` below.
                    kwargs.update(kwarg=kwarg)
                    raise self.retry(countdown=0, exc=exc)
        self.retry_task_customexc = retry_task_customexc


class MyCustomException(Exception):
    """Random custom exception."""


# Tests for ``Task.retry`` using the retry tasks from ``TasksCase.setup``.
class test_task_retries(TasksCase):

    # ``iterations`` is expected to be 4 with max_retries=3 and 11 when
    # ``max_retries=10`` is passed as a task keyword argument
    # (presumably one initial run plus one run per retry).
    def test_retry(self):
        self.retry_task.max_retries = 3
        self.retry_task.iterations = 0
        self.retry_task.apply([0xFF, 0xFFFF])
        self.assertEqual(self.retry_task.iterations, 4)

        self.retry_task.max_retries = 3
        self.retry_task.iterations = 0
        self.retry_task.apply([0xFF, 0xFFFF], {'max_retries': 10})
        self.assertEqual(self.retry_task.iterations, 11)

    def test_retry_no_args(self):
        self.retry_task_noargs.max_retries = 3
        self.retry_task_noargs.iterations = 0
        self.retry_task_noargs.apply(propagate=True).get()
        self.assertEqual(self.retry_task_noargs.iterations, 4)

    # ``retry`` must accept ``kwargs=None`` and still raise ``Retry``.
    def test_retry_kwargs_can_be_empty(self):
        self.retry_task_mockapply.push_request()
        try:
            with self.assertRaises(Retry):
                self.retry_task_mockapply.retry(args=[4, 4], kwargs=None)
        finally:
            self.retry_task_mockapply.pop_request()

    # With ``called_directly`` set to False, ``retry`` is expected to call
    # ``apply_async`` (observed through the ``applied`` counter) both with
    # ``throw=False`` and with ``throw=True``; only the latter raises.
    def test_retry_not_eager(self):
        self.retry_task_mockapply.push_request()
        try:
            self.retry_task_mockapply.request.called_directly = False
            exc = Exception('baz')
            try:
                self.retry_task_mockapply.retry(
                    args=[4, 4], kwargs={'task_retries': 0},
                    exc=exc, throw=False,
                )
                self.assertTrue(self.retry_task_mockapply.applied)
            finally:
                # Reset the counter so the next check starts from zero.
                self.retry_task_mockapply.applied = 0

            try:
                with self.assertRaises(Retry):
                    self.retry_task_mockapply.retry(
                        args=[4, 4], kwargs={'task_retries': 0},
                        exc=exc, throw=True)
                self.assertTrue(self.retry_task_mockapply.applied)
            finally:
                self.retry_task_mockapply.applied = 0
        finally:
            self.retry_task_mockapply.pop_request()

    def test_retry_with_kwargs(self):
        self.retry_task_customexc.max_retries = 3
        self.retry_task_customexc.iterations = 0
        self.retry_task_customexc.apply([0xFF, 0xFFFF], {'kwarg': 0xF})
        self.assertEqual(self.retry_task_customexc.iterations, 4)

    # With max_retries=2 the task cannot reach its ``retries >= 3`` return
    # branch; the result is expected to re-raise the custom exception that
    # was passed to ``retry`` as ``exc``.
    def test_retry_with_custom_exception(self):
        self.retry_task_customexc.max_retries = 2
        self.retry_task_customexc.iterations = 0
        result = self.retry_task_customexc.apply(
            [0xFF, 0xFFFF], {'kwarg': 0xF},
        )
        with self.assertRaises(MyCustomException):
            result.get()
        self.assertEqual(self.retry_task_customexc.iterations, 3)

    # ``care=False`` makes ``retry_task`` retry on every run, so the result
    # must raise ``MaxRetriesExceededError``; the asserted run count is
    # ``max_retries + 1`` in both scenarios.
    def test_max_retries_exceeded(self):
        self.retry_task.max_retries = 2
        self.retry_task.iterations = 0
        result = self.retry_task.apply([0xFF, 0xFFFF], {'care': False})
        with self.assertRaises(self.retry_task.MaxRetriesExceededError):
            result.get()
        self.assertEqual(self.retry_task.iterations, 3)

        self.retry_task.max_retries = 1
        self.retry_task.iterations = 0
        result = self.retry_task.apply([0xFF, 0xFFFF], {'care': False})
        with self.assertRaises(self.retry_task.MaxRetriesExceededError):
            result.get()
        self.assertEqual(self.retry_task.iterations, 2)


# Smoke tests: the signature helpers on a task return truthy objects.
class test_canvas_utils(TasksCase):

    def test_si(self):
        self.assertTrue(self.retry_task.si())
        self.assertTrue(self.retry_task.si().immutable)

    def test_chunks(self):
        self.assertTrue(self.retry_task.chunks(range(100), 10))

    def test_map(self):
        self.assertTrue(self.retry_task.map(range(100)))

    def test_starmap(self):
        self.assertTrue(self.retry_task.starmap(range(100)))

    # Only checks that ``on_success`` can be called without raising.
    def test_on_success(self):
        self.retry_task.on_success(1, 1, (), {})


# General Task behaviour: pickling, publishing, repr, request context,
# annotations and state updates.
class test_tasks(TasksCase):

    # Current time as reported by the app under test.
    def now(self):
        return self.app.now()

    # Unpickling a task must yield the instance registered in the task
    # registry of its app.
    @depends_on_current_app
    def test_unpickle_task(self):
        import pickle

        @self.app.task(shared=True)
        def xxx():
            pass
        self.assertIs(pickle.loads(pickle.dumps(xxx)), xxx.app.tasks[xxx.name])

    def test_AsyncResult(self):
        task_id = uuid()
        result = self.retry_task.AsyncResult(task_id)
        self.assertEqual(result.backend, self.retry_task.backend)
        self.assertEqual(result.id, task_id)

    # Helper assertion: fetches the next message from the consumer's first
    # queue, decodes it and checks id, task name and the given task kwargs.
    # With ``test_eta`` / ``test_expires`` it also checks that the field is
    # a string that ``parse_iso8601`` turns into a ``datetime``.
    def assertNextTaskDataEqual(self, consumer, presult, task_name,
                                test_eta=False, test_expires=False, **kwargs):
        next_task = consumer.queues[0].get(accept=['pickle'])
        task_data = next_task.decode()
        self.assertEqual(task_data['id'], presult.id)
        self.assertEqual(task_data['task'], task_name)
        task_kwargs = task_data.get('kwargs', {})
        if test_eta:
            self.assertIsInstance(task_data.get('eta'), string_t)
            to_datetime = parse_iso8601(task_data.get('eta'))
            self.assertIsInstance(to_datetime, datetime)
        if test_expires:
            self.assertIsInstance(task_data.get('expires'), string_t)
            to_datetime = parse_iso8601(task_data.get('expires'))
            self.assertIsInstance(to_datetime, datetime)
        for arg_name, arg_value in items(kwargs):
            self.assertEqual(task_kwargs.get(arg_name), arg_value)

    # A Task subclass that does not define ``run`` must raise
    # ``NotImplementedError`` when ``run`` is called.
    def test_incomplete_task_cls(self):

        class IncompleteTask(Task):
            app = self.app
            name = 'c.unittest.t.itask'

        with self.assertRaises(NotImplementedError):
            IncompleteTask().run()

    def test_task_kwargs_must_be_dictionary(self):
        with self.assertRaises(ValueError):
            self.increment_counter.apply_async([], 'str')

    def test_task_args_must_be_list(self):
        with self.assertRaises(ValueError):
            self.increment_counter.apply_async('str', {})

    # End-to-end check of ``self.mytask``: direct invocation, then
    # publishing through delay/apply_async/send_task and inspecting the
    # messages that arrive on the consumer's queue.
    def test_regular_task(self):
        self.assertIsInstance(self.mytask, Task)
        self.assertTrue(self.mytask.run())
        self.assertTrue(isinstance(self.mytask, Callable),
                        'Task class is callable()')
        self.assertTrue(self.mytask(), 'Task class runs run() when called')

        with self.app.connection_or_acquire() as conn:
            consumer = self.app.amqp.TaskConsumer(conn)
            with self.assertRaises(NotImplementedError):
                consumer.receive('foo', 'foo')
            # Start from an empty queue.
            consumer.purge()
            self.assertIsNone(consumer.queues[0].get())
            # NOTE(review): the consumer created here is discarded.
            self.app.amqp.TaskConsumer(conn, queues=[Queue('foo')])

            # Without arguments.
            presult = self.mytask.delay()
            self.assertNextTaskDataEqual(consumer, presult, self.mytask.name)

            # With arguments.
            presult2 = self.mytask.apply_async(
                kwargs=dict(name='George Costanza'),
            )
            self.assertNextTaskDataEqual(
                consumer, presult2, self.mytask.name, name='George Costanza',
            )

            # send_task
            sresult = self.app.send_task(self.mytask.name,
                                         kwargs=dict(name='Elaine M. Benes'))
            self.assertNextTaskDataEqual(
                consumer, sresult, self.mytask.name, name='Elaine M. Benes',
            )

            # With eta.
            presult2 = self.mytask.apply_async(
                kwargs=dict(name='George Costanza'),
                eta=self.now() + timedelta(days=1),
                expires=self.now() + timedelta(days=2),
            )
            self.assertNextTaskDataEqual(
                consumer, presult2, self.mytask.name,
                name='George Costanza', test_eta=True, test_expires=True,
            )

            # With countdown.
            presult2 = self.mytask.apply_async(
                kwargs=dict(name='George Costanza'), countdown=10, expires=12,
            )
            self.assertNextTaskDataEqual(
                consumer, presult2, self.mytask.name,
                name='George Costanza', test_eta=True, test_expires=True,
            )

            # Discarding all tasks.
            consumer.purge()
            self.mytask.apply_async()
            self.assertEqual(consumer.purge(), 1)
            self.assertIsNone(consumer.queues[0].get())

            # The first result only becomes successful once the backend is
            # told that the task is done.
            self.assertFalse(presult.successful())
            self.mytask.backend.mark_as_done(presult.id, result=None)
            self.assertTrue(presult.successful())

    def test_repr_v2_compat(self):
        self.mytask.__v2_compat__ = True
        self.assertIn('v2 compatible', repr(self.mytask))

    # A task created with ``__self__=42`` receives 42 as its first
    # argument both through ``apply`` and when called directly.
    def test_apply_with_self(self):

        @self.app.task(__self__=42, shared=False)
        def tawself(self):
            return self

        self.assertEqual(tawself.apply().get(), 42)

        self.assertEqual(tawself(), 42)

    # ``request.get`` returns a set attribute, or the supplied default
    # when the attribute is missing.
    def test_context_get(self):
        self.mytask.push_request()
        try:
            request = self.mytask.request
            request.foo = 32
            self.assertEqual(request.get('foo'), 32)
            self.assertEqual(request.get('bar', 36), 36)
            request.clear()
        finally:
            self.mytask.pop_request()

    def test_task_class_repr(self):
        self.assertIn('class Task of', repr(self.mytask.app.Task))
        # NOTE(review): ``_app`` is set to None and not restored here;
        # presumably each test gets a fresh app from AppCase -- confirm.
        self.mytask.app.Task._app = None
        self.assertIn('unbound', repr(self.mytask.app.Task, ))

    # Only checks that ``bind`` does not raise when
    # ``accept_magic_kwargs`` is None.
    def test_bind_no_magic_kwargs(self):
        self.mytask.accept_magic_kwargs = None
        self.mytask.bind(self.mytask.app)

    # With ``resolve_all_annotations`` patched to return a ``FOO`` entry,
    # ``annotate`` must set that attribute on the task.
    def test_annotate(self):
        with patch('celery.app.task.resolve_all_annotations') as anno:
            anno.return_value = [{'FOO': 'BAR'}]

            @self.app.task(shared=False)
            def task():
                pass
            task.annotate()
            self.assertEqual(task.FOO, 'BAR')

    # Only checks that ``after_return`` runs without raising when the
    # request carries a chord signature.
    def test_after_return(self):
        self.mytask.push_request()
        try:
            self.mytask.request.chord = self.mytask.s()
            self.mytask.after_return('SUCCESS', 1.0, 'foobar', (), {}, None)
            self.mytask.request.clear()
        finally:
            self.mytask.pop_request()

    def test_send_task_sent_event(self):
        with self.app.connection() as conn:
            self.app.conf.CELERY_SEND_TASK_SENT_EVENT = True
            self.assertTrue(self.app.amqp.TaskProducer(conn).send_sent_event)

    # ``update_state`` is exercised twice: with an explicit task id, and
    # with the id taken from the current request (``request.id``).
    def test_update_state(self):

        @self.app.task(shared=False)
        def yyy():
            pass

        yyy.push_request()
        try:
            tid = uuid()
            yyy.update_state(tid, 'FROBULATING', {'fooz': 'baaz'})
            self.assertEqual(yyy.AsyncResult(tid).status, 'FROBULATING')
            self.assertDictEqual(yyy.AsyncResult(tid).result, {'fooz': 'baaz'})

            yyy.request.id = tid
            yyy.update_state(state='FROBUZATING', meta={'fooz': 'baaz'})
            self.assertEqual(yyy.AsyncResult(tid).status, 'FROBUZATING')
            self.assertDictEqual(yyy.AsyncResult(tid).result, {'fooz': 'baaz'})
        finally:
            yyy.pop_request()

    def test_repr(self):

        @self.app.task(shared=False)
        def task_test_repr():
            pass

        self.assertIn('task_test_repr', repr(task_test_repr))

    def test_has___name__(self):

        @self.app.task(shared=False)
        def yyy2():
            pass

        self.assertTrue(yyy2.__name__)


# Tests for ``Task.apply``.
class test_apply_task(TasksCase):

    def test_apply_throw(self):
        with self.assertRaises(KeyError):
            self.raising.apply(throw=True)

    # With the setting enabled, ``apply`` itself must raise the task's
    # exception.
    def test_apply_with_CELERY_EAGER_PROPAGATES_EXCEPTIONS(self):
        self.app.conf.CELERY_EAGER_PROPAGATES_EXCEPTIONS = True
        with self.assertRaises(KeyError):
            self.raising.apply()

    # The counter lives on the task object, so the three ``apply`` calls
    # accumulate: 0 + 1 = 1, then + 1 = 2, then + 4 = 6.
    def test_apply(self):
        self.increment_counter.count = 0

        e = self.increment_counter.apply()
        self.assertIsInstance(e, EagerResult)
        self.assertEqual(e.get(), 1)

        e = self.increment_counter.apply(args=[1])
        self.assertEqual(e.get(), 2)

        e = self.increment_counter.apply(kwargs={'increment_by': 4})
        self.assertEqual(e.get(), 6)

        self.assertTrue(e.successful())
        self.assertTrue(e.ready())
        self.assertTrue(repr(e).startswith('