# -*- coding: utf-8 -*-
"""
celery.contrib.batches
======================

Experimental task class that buffers messages and processes them as a list.

.. warning::

    For this to work you have to set
    :setting:`CELERYD_PREFETCH_MULTIPLIER` to zero, or some value where
    the final multiplied value is higher than ``flush_every``.

    In the future we hope to add the ability to direct batching tasks
    to a channel with different QoS requirements than the task channel.
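    For example, assuming an app instance named ``app``, the prefetch
    multiplier can be disabled entirely with:

    .. code-block:: python

        app.conf.CELERYD_PREFETCH_MULTIPLIER = 0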
**Simple Example**

A click counter that flushes the buffer every 100 messages, and every
10 seconds.  Does not do anything with the data, but can easily be
modified to store it in a database.

.. code-block:: python

    # Flush after 100 messages, or 10 seconds.
    @app.task(base=Batches, flush_every=100, flush_interval=10)
    def count_click(requests):
        from collections import Counter
        clicks = Counter(request.kwargs['url'] for request in requests)
        for url, count in clicks.items():
            print('>>> Clicks: {0} -> {1}'.format(url, count))

Then you can ask for a click to be counted by doing::

    >>> count_click.delay(url='http://example.com')
**Example returning results**

An interface to the Web of Trust API that flushes the buffer every 100
messages, and every 10 seconds.

.. code-block:: python

    import requests
    from urlparse import urlparse

    from celery.contrib.batches import Batches

    wot_api_target = "https://api.mywot.com/0.4/public_link_json"

    @app.task(base=Batches, flush_every=100, flush_interval=10)
    def wot_api(requests):
        sig = lambda url: url
        responses = wot_api_real(
            (sig(*request.args, **request.kwargs) for request in requests)
        )
        # use mark_as_done to manually return response data
        for response, request in zip(responses, requests):
            app.backend.mark_as_done(request.id, response)

    def wot_api_real(urls):
        domains = [urlparse(url).netloc for url in urls]
        response = requests.get(
            wot_api_target,
            params={"hosts": '/'.join(set(domains)) + '/'}
        )
        return [response.json()[domain] for domain in domains]

Using the API is done as follows::

    >>> wot_api.delay('http://example.com')
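Since each request is marked as done manually, the stored value should be
retrievable like any other task result, provided a result backend is
configured (the call may block for up to ``flush_interval`` seconds before
the batch containing the request is flushed)::

    >>> result = wot_api.delay('http://example.com')
    >>> result.get()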
.. note::

    If you don't have an ``app`` instance then use the current app proxy
    instead::

        from celery import current_app
        current_app.backend.mark_as_done(request.id, response)

"""
from __future__ import absolute_import

from itertools import count

from celery.task import Task
from celery.five import Empty, Queue
from celery.utils.log import get_logger
from celery.worker.job import Request
from celery.utils import noop

__all__ = ['Batches']

logger = get_logger(__name__)
def consume_queue(queue):
    """Iterator yielding all immediately available items in a
    :class:`Queue.Queue`.

    The iterator stops as soon as the queue raises :exc:`Queue.Empty`.

    *Examples*

        >>> q = Queue()
        >>> for i in range(4):
        ...     q.put(i)
        >>> list(consume_queue(q))
        [0, 1, 2, 3]
        >>> list(consume_queue(q))
        []

    """
    get = queue.get_nowait
    while 1:
        try:
            yield get()
        except Empty:
            break
def apply_batches_task(task, args, loglevel, logfile):
    # Mimic a regular task invocation: push a request context, call the
    # task with the buffered requests, and log any error it raises.
    task.push_request(loglevel=loglevel, logfile=logfile)
    try:
        result = task(*args)
    except Exception as exc:
        result = None
        logger.error('Error: %r', exc, exc_info=True)
    finally:
        task.pop_request()
    return result
class SimpleRequest(object):
    """Pickleable request."""

    #: task id
    id = None

    #: task name
    name = None

    #: positional arguments
    args = ()

    #: keyword arguments
    kwargs = {}

    #: message delivery information.
    delivery_info = None

    #: worker node name
    hostname = None

    def __init__(self, id, name, args, kwargs, delivery_info, hostname):
        self.id = id
        self.name = name
        self.args = args
        self.kwargs = kwargs
        self.delivery_info = delivery_info
        self.hostname = hostname

    @classmethod
    def from_request(cls, request):
        return cls(request.id, request.name, request.args,
                   request.kwargs, request.delivery_info, request.hostname)
class Batches(Task):
    """Task base class that buffers messages and processes them as a list."""

    abstract = True

    #: Maximum number of messages in buffer.
    flush_every = 10

    #: Timeout in seconds before buffer is flushed anyway.
    flush_interval = 30

    def __init__(self):
        self._buffer = Queue()
        self._count = count(1)
        self._tref = None
        self._pool = None

    def run(self, requests):
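        """Process a single batch.

        Subclasses must override this method; ``requests`` is the list of
        :class:`SimpleRequest` objects buffered since the previous flush.
        """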
        raise NotImplementedError('must implement run(requests)')
    def Strategy(self, task, app, consumer):
        # Called by the worker to build the message handler for this task
        # type: instead of executing each message immediately, requests
        # are buffered and flushed in batches.
        self._pool = consumer.pool
        hostname = consumer.hostname
        eventer = consumer.event_dispatcher
        Req = Request
        connection_errors = consumer.connection_errors
        timer = consumer.timer
        put_buffer = self._buffer.put
        flush_buffer = self._do_flush

        def task_message_handler(message, body, ack, reject, callbacks, **kw):
            request = Req(body, on_ack=ack, app=app, hostname=hostname,
                          events=eventer, task=task,
                          connection_errors=connection_errors,
                          delivery_info=message.delivery_info)
            put_buffer(request)

            if self._tref is None:     # first request starts flush timer.
                self._tref = timer.call_repeatedly(
                    self.flush_interval, flush_buffer,
                )

            if not next(self._count) % self.flush_every:
                flush_buffer()

        return task_message_handler
    def flush(self, requests):
        return self.apply_buffer(requests, ([SimpleRequest.from_request(r)
                                             for r in requests], ))

    def _do_flush(self):
        logger.debug('Batches: Wake-up to flush buffer...')
        requests = None
        if self._buffer.qsize():
            requests = list(consume_queue(self._buffer))
            if requests:
                logger.debug('Batches: Buffer complete: %s', len(requests))
                self.flush(requests)
        if not requests:
            logger.debug('Batches: Cancelling timer: Nothing in buffer.')
            self._tref.cancel()  # cancel timer.
            self._tref = None
    def apply_buffer(self, requests, args=(), kwargs={}):
        # Partition requests by acks_late: index 0 holds early-ack
        # requests, index 1 holds late-ack requests.
        acks_late = [], []
        for r in requests:
            acks_late[r.task.acks_late].append(r)
        assert requests and (acks_late[True] or acks_late[False])

        def on_accepted(pid, time_accepted):
            # early-ack requests are acknowledged as soon as the pool
            # accepts the batch.
            for req in acks_late[False]:
                req.acknowledge()

        def on_return(result):
            # late-ack requests are acknowledged only after the batch
            # has returned.
            for req in acks_late[True]:
                req.acknowledge()

        return self._pool.apply_async(
            apply_batches_task,
            (self, args, 0, None),
            accept_callback=on_accepted,
            callback=acks_late[True] and on_return or noop,
        )