from __future__ import absolute_import from celery.contrib.abortable import AbortableTask, AbortableAsyncResult from celery.tests.case import AppCase class test_AbortableTask(AppCase): def setup(self): @self.app.task(base=AbortableTask, shared=False) def abortable(): return True self.abortable = abortable def test_async_result_is_abortable(self): result = self.abortable.apply_async() tid = result.id self.assertIsInstance( self.abortable.AsyncResult(tid), AbortableAsyncResult, ) def test_is_not_aborted(self): self.abortable.push_request() try: result = self.abortable.apply_async() tid = result.id self.assertFalse(self.abortable.is_aborted(task_id=tid)) finally: self.abortable.pop_request() def test_is_aborted_not_abort_result(self): self.abortable.AsyncResult = self.app.AsyncResult self.abortable.push_request() try: self.abortable.request.id = 'foo' self.assertFalse(self.abortable.is_aborted()) finally: self.abortable.pop_request() def test_abort_yields_aborted(self): self.abortable.push_request() try: result = self.abortable.apply_async() result.abort() tid = result.id self.assertTrue(self.abortable.is_aborted(task_id=tid)) finally: self.abortable.pop_request()