  1. # -*- coding: utf-8 -*-
  2. """
  3. celery.backends.base
  4. ~~~~~~~~~~~~~~~~~~~~
  5. Result backend base classes.
  6. - :class:`BaseBackend` defines the interface.
  7. - :class:`KeyValueStoreBackend` is a common base class
  8. using K/V semantics like _get and _put.
  9. """
  10. from __future__ import absolute_import
  11. import time
  12. import sys
  13. from datetime import timedelta
  14. from billiard.einfo import ExceptionInfo
  15. from kombu.serialization import (
  16. dumps, loads, prepare_accept_content,
  17. registry as serializer_registry,
  18. )
  19. from kombu.utils.encoding import bytes_to_str, ensure_bytes, from_utf8
  20. from celery import states
  21. from celery.app import current_task
  22. from celery.exceptions import ChordError, TimeoutError, TaskRevokedError
  23. from celery.five import items
  24. from celery.result import result_from_tuple, GroupResult
  25. from celery.utils import timeutils
  26. from celery.utils.functional import LRUCache
  27. from celery.utils.serialization import (
  28. get_pickled_exception,
  29. get_pickleable_exception,
  30. create_exception_cls,
  31. )
  32. __all__ = ['BaseBackend', 'KeyValueStoreBackend', 'DisabledBackend']
  33. EXCEPTION_ABLE_CODECS = frozenset(['pickle', 'yaml'])
  34. PY3 = sys.version_info >= (3, 0)
  35. def unpickle_backend(cls, args, kwargs):
  36. """Return an unpickled backend."""
  37. from celery import current_app
  38. return cls(*args, app=current_app._get_current_object(), **kwargs)
  39. class BaseBackend(object):
  40. READY_STATES = states.READY_STATES
  41. UNREADY_STATES = states.UNREADY_STATES
  42. EXCEPTION_STATES = states.EXCEPTION_STATES
  43. TimeoutError = TimeoutError
  44. #: Time to sleep between polling each individual item
  45. #: in `ResultSet.iterate`. as opposed to the `interval`
  46. #: argument which is for each pass.
  47. subpolling_interval = None
  48. #: If true the backend must implement :meth:`get_many`.
  49. supports_native_join = False
  50. #: If true the backend must automatically expire results.
  51. #: The daily backend_cleanup periodic task will not be triggered
  52. #: in this case.
  53. supports_autoexpire = False
  54. #: Set to true if the backend is peristent by default.
  55. persistent = True
  56. def __init__(self, app, serializer=None,
  57. max_cached_results=None, accept=None, **kwargs):
  58. self.app = app
  59. conf = self.app.conf
  60. self.serializer = serializer or conf.CELERY_RESULT_SERIALIZER
  61. (self.content_type,
  62. self.content_encoding,
  63. self.encoder) = serializer_registry._encoders[self.serializer]
  64. self._cache = LRUCache(
  65. limit=max_cached_results or conf.CELERY_MAX_CACHED_RESULTS,
  66. )
  67. self.accept = prepare_accept_content(
  68. conf.CELERY_ACCEPT_CONTENT if accept is None else accept,
  69. )
  70. def mark_as_started(self, task_id, **meta):
  71. """Mark a task as started"""
  72. return self.store_result(task_id, meta, status=states.STARTED)
  73. def mark_as_done(self, task_id, result, request=None):
  74. """Mark task as successfully executed."""
  75. return self.store_result(task_id, result,
  76. status=states.SUCCESS, request=request)
  77. def mark_as_failure(self, task_id, exc, traceback=None, request=None):
  78. """Mark task as executed with failure. Stores the execption."""
  79. return self.store_result(task_id, exc, status=states.FAILURE,
  80. traceback=traceback, request=request)
  81. def fail_from_current_stack(self, task_id, exc=None):
  82. type_, real_exc, tb = sys.exc_info()
  83. try:
  84. exc = real_exc if exc is None else exc
  85. ei = ExceptionInfo((type_, exc, tb))
  86. self.mark_as_failure(task_id, exc, ei.traceback)
  87. return ei
  88. finally:
  89. del(tb)
  90. def mark_as_retry(self, task_id, exc, traceback=None, request=None):
  91. """Mark task as being retries. Stores the current
  92. exception (if any)."""
  93. return self.store_result(task_id, exc, status=states.RETRY,
  94. traceback=traceback, request=request)
  95. def mark_as_revoked(self, task_id, reason='', request=None):
  96. return self.store_result(task_id, TaskRevokedError(reason),
  97. status=states.REVOKED, traceback=None,
  98. request=request)
  99. def prepare_exception(self, exc):
  100. """Prepare exception for serialization."""
  101. if self.serializer in EXCEPTION_ABLE_CODECS:
  102. return get_pickleable_exception(exc)
  103. return {'exc_type': type(exc).__name__, 'exc_message': str(exc)}
  104. def exception_to_python(self, exc):
  105. """Convert serialized exception to Python exception."""
  106. if self.serializer in EXCEPTION_ABLE_CODECS:
  107. return get_pickled_exception(exc)
  108. return create_exception_cls(
  109. from_utf8(exc['exc_type']), __name__)(exc['exc_message'])
  110. def prepare_value(self, result):
  111. """Prepare value for storage."""
  112. if isinstance(result, GroupResult):
  113. return result.as_tuple()
  114. return result
  115. def encode(self, data):
  116. _, _, payload = dumps(data, serializer=self.serializer)
  117. return payload
  118. def decode(self, payload):
  119. payload = PY3 and payload or str(payload)
  120. return loads(payload,
  121. content_type=self.content_type,
  122. content_encoding=self.content_encoding,
  123. accept=self.accept)
  124. def wait_for(self, task_id, timeout=None, propagate=True, interval=0.5):
  125. """Wait for task and return its result.
  126. If the task raises an exception, this exception
  127. will be re-raised by :func:`wait_for`.
  128. If `timeout` is not :const:`None`, this raises the
  129. :class:`celery.exceptions.TimeoutError` exception if the operation
  130. takes longer than `timeout` seconds.
  131. """
  132. time_elapsed = 0.0
  133. while 1:
  134. status = self.get_status(task_id)
  135. if status == states.SUCCESS:
  136. return self.get_result(task_id)
  137. elif status in states.PROPAGATE_STATES:
  138. result = self.get_result(task_id)
  139. if propagate:
  140. raise result
  141. return result
  142. # avoid hammering the CPU checking status.
  143. time.sleep(interval)
  144. time_elapsed += interval
  145. if timeout and time_elapsed >= timeout:
  146. raise TimeoutError('The operation timed out.')
  147. def prepare_expires(self, value, type=None):
  148. if value is None:
  149. value = self.app.conf.CELERY_TASK_RESULT_EXPIRES
  150. if isinstance(value, timedelta):
  151. value = timeutils.timedelta_seconds(value)
  152. if value is not None and type:
  153. return type(value)
  154. return value
  155. def prepare_persistent(self, enabled=None):
  156. if enabled is not None:
  157. return enabled
  158. p = self.app.conf.CELERY_RESULT_PERSISTENT
  159. return self.persistent if p is None else p
  160. def encode_result(self, result, status):
  161. if status in self.EXCEPTION_STATES and isinstance(result, Exception):
  162. return self.prepare_exception(result)
  163. else:
  164. return self.prepare_value(result)
  165. def is_cached(self, task_id):
  166. return task_id in self._cache
  167. def store_result(self, task_id, result, status,
  168. traceback=None, request=None, **kwargs):
  169. """Update task state and result."""
  170. result = self.encode_result(result, status)
  171. self._store_result(task_id, result, status, traceback,
  172. request=request, **kwargs)
  173. return result
  174. def forget(self, task_id):
  175. self._cache.pop(task_id, None)
  176. self._forget(task_id)
  177. def _forget(self, task_id):
  178. raise NotImplementedError('backend does not implement forget.')
  179. def get_status(self, task_id):
  180. """Get the status of a task."""
  181. return self.get_task_meta(task_id)['status']
  182. def get_traceback(self, task_id):
  183. """Get the traceback for a failed task."""
  184. return self.get_task_meta(task_id).get('traceback')
  185. def get_result(self, task_id):
  186. """Get the result of a task."""
  187. meta = self.get_task_meta(task_id)
  188. if meta['status'] in self.EXCEPTION_STATES:
  189. return self.exception_to_python(meta['result'])
  190. else:
  191. return meta['result']
  192. def get_children(self, task_id):
  193. """Get the list of subtasks sent by a task."""
  194. try:
  195. return self.get_task_meta(task_id)['children']
  196. except KeyError:
  197. pass
  198. def get_task_meta(self, task_id, cache=True):
  199. if cache:
  200. try:
  201. return self._cache[task_id]
  202. except KeyError:
  203. pass
  204. meta = self._get_task_meta_for(task_id)
  205. if cache and meta.get('status') == states.SUCCESS:
  206. self._cache[task_id] = meta
  207. return meta
  208. def reload_task_result(self, task_id):
  209. """Reload task result, even if it has been previously fetched."""
  210. self._cache[task_id] = self.get_task_meta(task_id, cache=False)
  211. def reload_group_result(self, group_id):
  212. """Reload group result, even if it has been previously fetched."""
  213. self._cache[group_id] = self.get_group_meta(group_id, cache=False)
  214. def get_group_meta(self, group_id, cache=True):
  215. if cache:
  216. try:
  217. return self._cache[group_id]
  218. except KeyError:
  219. pass
  220. meta = self._restore_group(group_id)
  221. if cache and meta is not None:
  222. self._cache[group_id] = meta
  223. return meta
  224. def restore_group(self, group_id, cache=True):
  225. """Get the result for a group."""
  226. meta = self.get_group_meta(group_id, cache=cache)
  227. if meta:
  228. return meta['result']
  229. def save_group(self, group_id, result):
  230. """Store the result of an executed group."""
  231. return self._save_group(group_id, result)
  232. def delete_group(self, group_id):
  233. self._cache.pop(group_id, None)
  234. return self._delete_group(group_id)
  235. def cleanup(self):
  236. """Backend cleanup. Is run by
  237. :class:`celery.task.DeleteExpiredTaskMetaTask`."""
  238. pass
  239. def process_cleanup(self):
  240. """Cleanup actions to do at the end of a task worker process."""
  241. pass
  242. def on_task_call(self, producer, task_id):
  243. return {}
  244. def on_chord_part_return(self, task, propagate=False):
  245. pass
  246. def fallback_chord_unlock(self, group_id, body, result=None,
  247. countdown=1, **kwargs):
  248. kwargs['result'] = [r.as_tuple() for r in result]
  249. self.app.tasks['celery.chord_unlock'].apply_async(
  250. (group_id, body, ), kwargs, countdown=countdown,
  251. )
  252. on_chord_apply = fallback_chord_unlock
  253. def current_task_children(self, request=None):
  254. request = request or getattr(current_task(), 'request', None)
  255. if request:
  256. return [r.as_tuple() for r in getattr(request, 'children', [])]
  257. def __reduce__(self, args=(), kwargs={}):
  258. return (unpickle_backend, (self.__class__, args, kwargs))
BaseDictBackend = BaseBackend  # XXX compat: old name kept for backwards compatibility
  260. class KeyValueStoreBackend(BaseBackend):
  261. task_keyprefix = ensure_bytes('celery-task-meta-')
  262. group_keyprefix = ensure_bytes('celery-taskset-meta-')
  263. chord_keyprefix = ensure_bytes('chord-unlock-')
  264. implements_incr = False
  265. def get(self, key):
  266. raise NotImplementedError('Must implement the get method.')
  267. def mget(self, keys):
  268. raise NotImplementedError('Does not support get_many')
  269. def set(self, key, value):
  270. raise NotImplementedError('Must implement the set method.')
  271. def delete(self, key):
  272. raise NotImplementedError('Must implement the delete method')
  273. def incr(self, key):
  274. raise NotImplementedError('Does not implement incr')
  275. def expire(self, key, value):
  276. pass
  277. def get_key_for_task(self, task_id):
  278. """Get the cache key for a task by id."""
  279. return self.task_keyprefix + ensure_bytes(task_id)
  280. def get_key_for_group(self, group_id):
  281. """Get the cache key for a group by id."""
  282. return self.group_keyprefix + ensure_bytes(group_id)
  283. def get_key_for_chord(self, group_id):
  284. """Get the cache key for the chord waiting on group with given id."""
  285. return self.chord_keyprefix + ensure_bytes(group_id)
  286. def _strip_prefix(self, key):
  287. """Takes bytes, emits string."""
  288. key = ensure_bytes(key)
  289. for prefix in self.task_keyprefix, self.group_keyprefix:
  290. if key.startswith(prefix):
  291. return bytes_to_str(key[len(prefix):])
  292. return bytes_to_str(key)
  293. def _mget_to_results(self, values, keys):
  294. if hasattr(values, 'items'):
  295. # client returns dict so mapping preserved.
  296. return dict((self._strip_prefix(k), self.decode(v))
  297. for k, v in items(values)
  298. if v is not None)
  299. else:
  300. # client returns list so need to recreate mapping.
  301. return dict((bytes_to_str(keys[i]), self.decode(value))
  302. for i, value in enumerate(values)
  303. if value is not None)
  304. def get_many(self, task_ids, timeout=None, interval=0.5,
  305. READY_STATES=states.READY_STATES):
  306. interval = 0.5 if interval is None else interval
  307. ids = task_ids if isinstance(task_ids, set) else set(task_ids)
  308. cached_ids = set()
  309. cache = self._cache
  310. for task_id in ids:
  311. try:
  312. cached = cache[task_id]
  313. except KeyError:
  314. pass
  315. else:
  316. if cached['status'] in READY_STATES:
  317. yield bytes_to_str(task_id), cached
  318. cached_ids.add(task_id)
  319. ids.difference_update(cached_ids)
  320. iterations = 0
  321. while ids:
  322. keys = list(ids)
  323. r = self._mget_to_results(self.mget([self.get_key_for_task(k)
  324. for k in keys]), keys)
  325. cache.update(r)
  326. ids.difference_update(set(bytes_to_str(v) for v in r))
  327. for key, value in items(r):
  328. yield bytes_to_str(key), value
  329. if timeout and iterations * interval >= timeout:
  330. raise TimeoutError('Operation timed out ({0})'.format(timeout))
  331. time.sleep(interval) # don't busy loop.
  332. iterations += 1
  333. def _forget(self, task_id):
  334. self.delete(self.get_key_for_task(task_id))
  335. def _store_result(self, task_id, result, status,
  336. traceback=None, request=None, **kwargs):
  337. meta = {'status': status, 'result': result, 'traceback': traceback,
  338. 'children': self.current_task_children(request)}
  339. self.set(self.get_key_for_task(task_id), self.encode(meta))
  340. return result
  341. def _save_group(self, group_id, result):
  342. self.set(self.get_key_for_group(group_id),
  343. self.encode({'result': result.as_tuple()}))
  344. return result
  345. def _delete_group(self, group_id):
  346. self.delete(self.get_key_for_group(group_id))
  347. def _get_task_meta_for(self, task_id):
  348. """Get task metadata for a task by id."""
  349. meta = self.get(self.get_key_for_task(task_id))
  350. if not meta:
  351. return {'status': states.PENDING, 'result': None}
  352. return self.decode(meta)
  353. def _restore_group(self, group_id):
  354. """Get task metadata for a task by id."""
  355. meta = self.get(self.get_key_for_group(group_id))
  356. # previously this was always pickled, but later this
  357. # was extended to support other serializers, so the
  358. # structure is kind of weird.
  359. if meta:
  360. meta = self.decode(meta)
  361. result = meta['result']
  362. meta['result'] = result_from_tuple(result, self.app)
  363. return meta
  364. def on_chord_apply(self, group_id, body, result=None, **kwargs):
  365. if self.implements_incr:
  366. self.save_group(group_id, self.app.GroupResult(group_id, result))
  367. else:
  368. self.fallback_chord_unlock(group_id, body, result, **kwargs)
  369. def on_chord_part_return(self, task, propagate=None):
  370. if not self.implements_incr:
  371. return
  372. from celery import maybe_signature
  373. from celery.result import GroupResult
  374. app = self.app
  375. if propagate is None:
  376. propagate = self.app.conf.CELERY_CHORD_PROPAGATES
  377. gid = task.request.group
  378. if not gid:
  379. return
  380. key = self.get_key_for_chord(gid)
  381. try:
  382. deps = GroupResult.restore(gid, backend=task.backend)
  383. except Exception as exc:
  384. callback = maybe_signature(task.request.chord, app=self.app)
  385. return app._tasks[callback.task].backend.fail_from_current_stack(
  386. callback.id,
  387. exc=ChordError('Cannot restore group: {0!r}'.format(exc)),
  388. )
  389. if deps is None:
  390. try:
  391. raise ValueError(gid)
  392. except ValueError as exc:
  393. callback = maybe_signature(task.request.chord, app=self.app)
  394. task = app._tasks[callback.task]
  395. return task.backend.fail_from_current_stack(
  396. callback.id,
  397. exc=ChordError('GroupResult {0} no longer exists'.format(
  398. gid,
  399. ))
  400. )
  401. val = self.incr(key)
  402. if val >= len(deps):
  403. callback = maybe_signature(task.request.chord, app=self.app)
  404. j = deps.join_native if deps.supports_native_join else deps.join
  405. try:
  406. ret = j(propagate=propagate)
  407. except Exception as exc:
  408. try:
  409. culprit = next(deps._failed_join_report())
  410. reason = 'Dependency {0.id} raised {1!r}'.format(
  411. culprit, exc,
  412. )
  413. except StopIteration:
  414. reason = repr(exc)
  415. app._tasks[callback.task].backend.fail_from_current_stack(
  416. callback.id, exc=ChordError(reason),
  417. )
  418. else:
  419. try:
  420. callback.delay(ret)
  421. except Exception as exc:
  422. app._tasks[callback.task].backend.fail_from_current_stack(
  423. callback.id,
  424. exc=ChordError('Callback error: {0!r}'.format(exc)),
  425. )
  426. finally:
  427. deps.delete()
  428. self.client.delete(key)
  429. else:
  430. self.expire(key, 86400)
  431. class DisabledBackend(BaseBackend):
  432. _cache = {} # need this attribute to reset cache in tests.
  433. def store_result(self, *args, **kwargs):
  434. pass
  435. def _is_disabled(self, *args, **kwargs):
  436. raise NotImplementedError(
  437. 'No result backend configured. '
  438. 'Please see the documentation for more information.')
  439. wait_for = get_status = get_result = get_traceback = _is_disabled