registry.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import pickle
  2. from huey.exceptions import QueueException
  3. class TaskRegistry(object):
  4. """
  5. A simple Registry used to track subclasses of :class:`QueueTask` - the
  6. purpose of this registry is to allow translation from queue messages to
  7. task classes, and vice-versa.
  8. """
  9. _ignore = ['QueueTask', 'PeriodicQueueTask', 'NewBase']
  10. def __init__(self):
  11. self._registry = {}
  12. self._periodic_tasks = []
  13. def task_to_string(self, task):
  14. return '%s' % (task.__name__)
  15. def register(self, task_class):
  16. klass_str = self.task_to_string(task_class)
  17. if klass_str in self._ignore:
  18. return
  19. if klass_str not in self._registry:
  20. self._registry[klass_str] = task_class
  21. # store an instance in a separate list of periodic tasks
  22. if hasattr(task_class, 'validate_datetime'):
  23. self._periodic_tasks.append(task_class)
  24. def unregister(self, task_class):
  25. klass_str = self.task_to_string(task_class)
  26. if klass_str in self._registry:
  27. del(self._registry[klass_str])
  28. for task in self._periodic_tasks:
  29. if isinstance(task, task_class):
  30. self._periodic_tasks.remove(task)
  31. def __contains__(self, klass_str):
  32. return klass_str in self._registry
  33. def get_message_for_task(self, task):
  34. """Convert a task object to a message for storage in the queue"""
  35. data = task.get_data()
  36. if data and isinstance(data, tuple) and len(data) == 2:
  37. args, kwargs = data
  38. if isinstance(kwargs, dict) and 'task' in kwargs:
  39. kwargs.pop('task')
  40. data = (args, kwargs)
  41. if task.on_complete is not None:
  42. on_complete = self.get_message_for_task(task.on_complete)
  43. else:
  44. on_complete = None
  45. return pickle.dumps((
  46. task.task_id,
  47. self.task_to_string(type(task)),
  48. task.execute_time,
  49. task.retries,
  50. task.retry_delay,
  51. data,
  52. on_complete))
  53. def get_task_class(self, klass_str):
  54. klass = self._registry.get(klass_str)
  55. if not klass:
  56. raise QueueException('%s not found in TaskRegistry' % klass_str)
  57. return klass
  58. def get_task_for_message(self, msg):
  59. """Convert a message from the queue into a task"""
  60. # parse out the pieces from the enqueued message
  61. raw = pickle.loads(msg)
  62. if len(raw) == 7:
  63. task_id, klass_str, ex_time, retries, delay, data, oc_raw = raw
  64. elif len(raw) == 6:
  65. task_id, klass_str, ex_time, retries, delay, data = raw
  66. oc_raw = None
  67. klass = self.get_task_class(klass_str)
  68. on_complete = self.get_task_for_message(oc_raw) if oc_raw else None
  69. return klass(data, task_id, ex_time, retries, delay, on_complete)
  70. def get_periodic_tasks(self):
  71. return [task_class() for task_class in self._periodic_tasks]
  72. registry = TaskRegistry()